diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3f53811f85..94e857f93a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,7 +36,6 @@ /api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost -/api/graphon/model_runtime/ @laipz8200 @WH-2099 # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) /api/core/workflow/nodes/agent/ @Nov1c444 diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 6b87946221..7bce056970 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -14,18 +14,17 @@ concurrency: cancel-in-progress: true jobs: - test: - name: API Tests + api-unit: + name: API Unit Tests runs-on: ubuntu-latest env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + COVERAGE_FILE: coverage-unit defaults: run: shell: bash strategy: matrix: python-version: - - "3.11" - "3.12" steps: @@ -51,6 +50,52 @@ jobs: - name: Run dify config tests run: uv run --project api dev/pytest/pytest_config_tests.py + - name: Run Unit Tests + run: uv run --project api bash dev/pytest/pytest_unit_tests.sh + + - name: Upload unit coverage data + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: api-coverage-unit + path: coverage-unit + retention-days: 1 + + api-integration: + name: API Integration Tests + runs-on: ubuntu-latest + env: + COVERAGE_FILE: coverage-integration + STORAGE_TYPE: opendal + OPENDAL_SCHEME: fs + OPENDAL_FS_ROOT: /tmp/dify-storage + defaults: + run: + shell: bash + strategy: + matrix: + python-version: + - "3.12" + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Setup UV and Python + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + with: + enable-cache: true + python-version: ${{ matrix.python-version }} + cache-dependency-glob: api/uv.lock + + - name: Check UV lockfile + run: uv lock --project api --check + + - name: Install dependencies + run: uv sync --project api --dev + - name: Set up dotenvs run: | cp docker/.env.example docker/.env @@ -74,22 +119,90 @@ jobs: run: | cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env - - name: Run API Tests - env: - STORAGE_TYPE: opendal - OPENDAL_SCHEME: fs - OPENDAL_FS_ROOT: /tmp/dify-storage + - name: Run Integration Tests run: | uv run --project api pytest \ -n auto \ --timeout "${PYTEST_TIMEOUT:-180}" \ api/tests/integration_tests/workflow \ api/tests/integration_tests/tools \ - api/tests/test_containers_integration_tests \ - api/tests/unit_tests + api/tests/test_containers_integration_tests + + - name: Upload integration coverage data + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: api-coverage-integration + path: coverage-integration + retention-days: 1 + + api-coverage: + name: API Coverage + runs-on: ubuntu-latest + needs: + - api-unit + - api-integration + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + COVERAGE_FILE: .coverage + defaults: + run: + shell: bash + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Setup UV and Python + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + with: + enable-cache: true + python-version: "3.12" + 
cache-dependency-glob: api/uv.lock
+
+      - name: Install dependencies
+        run: uv sync --project api --dev
+
+      - name: Download coverage data
+        uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
+        with:
+          path: coverage-data
+          pattern: api-coverage-*
+          merge-multiple: true
+
+      - name: Combine coverage
+        run: |
+          set -euo pipefail
+
+          echo "### API Coverage" >> "$GITHUB_STEP_SUMMARY"
+          echo "" >> "$GITHUB_STEP_SUMMARY"
+          echo "Merged backend coverage report generated for Codecov project status." >> "$GITHUB_STEP_SUMMARY"
+          echo "" >> "$GITHUB_STEP_SUMMARY"
+
+          unit_coverage="$(find coverage-data -type f -name coverage-unit -print -quit)"
+          integration_coverage="$(find coverage-data -type f -name coverage-integration -print -quit)"
+          : "${unit_coverage:?coverage-unit artifact not found}"
+          : "${integration_coverage:?coverage-integration artifact not found}"
+
+          report_file="$(mktemp)"
+          uv run --project api coverage combine "$unit_coverage" "$integration_coverage"
+          uv run --project api coverage report --show-missing | tee "$report_file"
+          echo "Summary: \`$(tail -n 1 "$report_file")\`" >> "$GITHUB_STEP_SUMMARY"
+          {
+            echo ""
+            echo "<details><summary>Coverage report</summary>"
+            echo ""
+            echo '```'
+            cat "$report_file"
+            echo '```'
+            echo "</details>
" + } >> "$GITHUB_STEP_SUMMARY" + uv run --project api coverage xml -o coverage.xml - name: Report coverage - if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }} + if: ${{ env.CODECOV_TOKEN != '' }} uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 with: files: ./coverage.xml diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index be6186980e..d8a53c9594 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -2,6 +2,9 @@ name: autofix.ci on: pull_request: branches: ["main"] + merge_group: + branches: ["main"] + types: [checks_requested] push: branches: ["main"] permissions: @@ -12,9 +15,15 @@ jobs: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Complete merge group check + if: github.event_name == 'merge_group' + run: echo "autofix.ci updates pull request branches, not merge group refs." + + - if: github.event_name != 'merge_group' + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Check Docker Compose inputs + if: github.event_name != 'merge_group' id: docker-compose-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: @@ -24,30 +33,34 @@ jobs: docker/docker-compose-template.yaml docker/docker-compose.yaml - name: Check web inputs + if: github.event_name != 'merge_group' id: web-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | web/** - name: Check api inputs + if: github.event_name != 'merge_group' id: api-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | api/** - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + - if: github.event_name != 'merge_group' + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.11" - - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + - if: github.event_name != 'merge_group' + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 - name: Generate Docker Compose - if: steps.docker-compose-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true' run: | cd docker ./generate_docker_compose - - if: steps.api-changes.outputs.any_changed == 'true' + - if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | cd api uv sync --dev @@ -59,13 +72,13 @@ jobs: uv run ruff format .. - name: count migration progress - if: steps.api-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | cd api ./cnt_base.sh - name: ast-grep - if: steps.api-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | # ast-grep exits 1 if no matches are found; allow idempotent runs. uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true @@ -95,13 +108,14 @@ jobs: find . 
-name "*.py.bak" -type f -delete - name: Setup web environment - if: steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' uses: ./.github/actions/setup-web - name: ESLint autofix - if: steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' run: | cd web vp exec eslint --concurrency=2 --prune-suppressions --quiet || true - - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 + - if: github.event_name != 'merge_group' + uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index 69023c24cc..2d96dae4da 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -3,10 +3,14 @@ name: Main CI Pipeline on: pull_request: branches: ["main"] + merge_group: + branches: ["main"] + types: [checks_requested] push: branches: ["main"] permissions: + actions: write contents: write pull-requests: write checks: write @@ -17,12 +21,28 @@ concurrency: cancel-in-progress: true jobs: + pre_job: + name: Skip Duplicate Checks + runs-on: ubuntu-latest + outputs: + should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }} + steps: + - id: skip_check + continue-on-error: true + uses: fkirc/skip-duplicate-actions@f75f66ce1886f00957d99748a42c724f4330bdcf # v5.3.1 + with: + cancel_others: 'true' + concurrent_skipping: same_content_newer + # Check which paths were changed to determine which tests to run check-changes: name: Check Changed Files + needs: pre_job + if: needs.pre_job.outputs.should_skip != 'true' runs-on: ubuntu-latest outputs: api-changed: ${{ steps.changes.outputs.api }} + e2e-changed: ${{ steps.changes.outputs.e2e }} web-changed: ${{ steps.changes.outputs.web }} vdb-changed: ${{ steps.changes.outputs.vdb }} migration-changed: ${{ steps.changes.outputs.migration }} @@ -34,49 +54,364 @@ jobs: filters: | api: - 'api/**' - - 'docker/**' - '.github/workflows/api-tests.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 'docker/middleware.env.example' + - 'docker/docker-compose.middleware.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/ssrf_proxy/**' + - 'docker/volumes/sandbox/conf/**' web: - 'web/**' - '.github/workflows/web-tests.yml' - '.github/actions/setup-web/**' + e2e: + - 'api/**' + - 'api/pyproject.toml' + - 'api/uv.lock' + - 'e2e/**' + - 'web/**' + - 'docker/docker-compose.middleware.yaml' + - 'docker/middleware.env.example' + - '.github/workflows/web-e2e.yml' + - '.github/actions/setup-web/**' vdb: - 'api/core/rag/datasource/**' - - 'docker/**' + - 'api/tests/integration_tests/vdb/**' - '.github/workflows/vdb-tests.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 'docker/middleware.env.example' + - 'docker/docker-compose.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/certbot/**' + - 'docker/couchbase-server/**' + - 'docker/elasticsearch/**' + - 'docker/iris/**' + - 'docker/nginx/**' + - 'docker/pgvector/**' + - 'docker/ssrf_proxy/**' + - 'docker/startupscripts/**' + - 'docker/tidb/**' + - 'docker/volumes/**' - 'api/uv.lock' - 'api/pyproject.toml' migration: - 'api/migrations/**' + - 'api/.env.example' - '.github/workflows/db-migration-test.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 
'docker/middleware.env.example' + - 'docker/docker-compose.middleware.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/ssrf_proxy/**' + - 'docker/volumes/sandbox/conf/**' - # Run tests in parallel - api-tests: - name: API Tests - needs: check-changes - if: needs.check-changes.outputs.api-changed == 'true' + # Run tests in parallel while always emitting stable required checks. + api-tests-run: + name: Run API Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed == 'true' uses: ./.github/workflows/api-tests.yml secrets: inherit - web-tests: - name: Web Tests - needs: check-changes - if: needs.check-changes.outputs.web-changed == 'true' + api-tests-skip: + name: Skip API Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped API tests + run: echo "No API-related changes detected; skipping API tests." + + api-tests: + name: API Tests + if: ${{ always() }} + needs: + - pre_job + - check-changes + - api-tests-run + - api-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize API Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.api-changed }} + RUN_RESULT: ${{ needs.api-tests-run.result }} + SKIP_RESULT: ${{ needs.api-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "API tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "API tests ran successfully." + exit 0 + fi + + echo "API tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "API tests were skipped because no API-related files changed." + exit 0 + fi + + echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + web-tests-run: + name: Run Web Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed == 'true' uses: ./.github/workflows/web-tests.yml secrets: inherit + web-tests-skip: + name: Skip Web Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped web tests + run: echo "No web-related changes detected; skipping web tests." + + web-tests: + name: Web Tests + if: ${{ always() }} + needs: + - pre_job + - check-changes + - web-tests-run + - web-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize Web Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.web-changed }} + RUN_RESULT: ${{ needs.web-tests-run.result }} + SKIP_RESULT: ${{ needs.web-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "Web tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "Web tests ran successfully." 
+ exit 0 + fi + + echo "Web tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "Web tests were skipped because no web-related files changed." + exit 0 + fi + + echo "Web tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + web-e2e-run: + name: Run Web Full-Stack E2E + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed == 'true' + uses: ./.github/workflows/web-e2e.yml + + web-e2e-skip: + name: Skip Web Full-Stack E2E + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped web full-stack e2e + run: echo "No E2E-related changes detected; skipping web full-stack E2E." + + web-e2e: + name: Web Full-Stack E2E + if: ${{ always() }} + needs: + - pre_job + - check-changes + - web-e2e-run + - web-e2e-skip + runs-on: ubuntu-latest + steps: + - name: Finalize Web Full-Stack E2E status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.e2e-changed }} + RUN_RESULT: ${{ needs.web-e2e-run.result }} + SKIP_RESULT: ${{ needs.web-e2e-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "Web full-stack E2E was skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "Web full-stack E2E ran successfully." + exit 0 + fi + + echo "Web full-stack E2E was required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "Web full-stack E2E was skipped because no E2E-related files changed." + exit 0 + fi + + echo "Web full-stack E2E was not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + style-check: name: Style Check + needs: pre_job + if: needs.pre_job.outputs.should_skip != 'true' uses: ./.github/workflows/style.yml + vdb-tests-run: + name: Run VDB Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed == 'true' + uses: ./.github/workflows/vdb-tests.yml + + vdb-tests-skip: + name: Skip VDB Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped VDB tests + run: echo "No VDB-related changes detected; skipping VDB tests." + vdb-tests: name: VDB Tests - needs: check-changes - if: needs.check-changes.outputs.vdb-changed == 'true' - uses: ./.github/workflows/vdb-tests.yml + if: ${{ always() }} + needs: + - pre_job + - check-changes + - vdb-tests-run + - vdb-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize VDB Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.vdb-changed }} + RUN_RESULT: ${{ needs.vdb-tests-run.result }} + SKIP_RESULT: ${{ needs.vdb-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "VDB tests were skipped because this workflow run duplicated a successful or newer run." 
+ exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "VDB tests ran successfully." + exit 0 + fi + + echo "VDB tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "VDB tests were skipped because no VDB-related files changed." + exit 0 + fi + + echo "VDB tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + db-migration-test-run: + name: Run DB Migration Test + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed == 'true' + uses: ./.github/workflows/db-migration-test.yml + + db-migration-test-skip: + name: Skip DB Migration Test + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped DB migration tests + run: echo "No migration-related changes detected; skipping DB migration tests." db-migration-test: name: DB Migration Test - needs: check-changes - if: needs.check-changes.outputs.migration-changed == 'true' - uses: ./.github/workflows/db-migration-test.yml + if: ${{ always() }} + needs: + - pre_job + - check-changes + - db-migration-test-run + - db-migration-test-skip + runs-on: ubuntu-latest + steps: + - name: Finalize DB Migration Test status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.migration-changed }} + RUN_RESULT: ${{ needs.db-migration-test-run.result }} + SKIP_RESULT: ${{ needs.db-migration-test-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "DB migration tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "DB migration tests ran successfully." + exit 0 + fi + + echo "DB migration tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "DB migration tests were skipped because no migration-related files changed." + exit 0 + fi + + echo "DB migration tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml index c21331ec0d..49d2e94695 100644 --- a/.github/workflows/semantic-pull-request.yml +++ b/.github/workflows/semantic-pull-request.yml @@ -7,6 +7,9 @@ on: - edited - reopened - synchronize + merge_group: + branches: ["main"] + types: [checks_requested] jobs: lint: @@ -15,7 +18,11 @@ jobs: pull-requests: read runs-on: ubuntu-latest steps: + - name: Complete merge group check + if: github.event_name == 'merge_group' + run: echo "Semantic PR title validation is handled on pull requests." 
- name: Check title + if: github.event_name == 'pull_request' uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 23ae36f7b1..7b269ccf4e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -49,7 +49,7 @@ jobs: - name: Run Type Checks if: steps.changed-files.outputs.any_changed == 'true' - run: make type-check + run: make type-check-core - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 1869254295..aaf51aa606 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -1,26 +1,24 @@ name: Translate i18n Files with Claude Code -# Note: claude-code-action doesn't support push events directly. -# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch. -# See: https://github.com/langgenius/dify/issues/30743 - on: - repository_dispatch: - types: [i18n-sync] + push: + branches: [main] + paths: + - 'web/i18n/en-US/*.json' workflow_dispatch: inputs: files: - description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.' + description: 'Specific files to translate (space-separated, e.g., "app common"). Required for full mode; leave empty in incremental mode to use en-US files changed since HEAD~1.' required: false type: string languages: - description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.' + description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported target languages except en-US.' 
required: false type: string mode: - description: 'Sync mode: incremental (only changes) or full (re-check all keys)' + description: 'Sync mode: incremental (compare with previous en-US revision) or full (sync all keys in scope)' required: false - default: 'incremental' + default: incremental type: choice options: - incremental @@ -30,11 +28,15 @@ permissions: contents: write pull-requests: write +concurrency: + group: translate-i18n-${{ github.event_name }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'push' }} + jobs: translate: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest - timeout-minutes: 60 + timeout-minutes: 120 steps: - name: Checkout repository @@ -51,380 +53,161 @@ jobs: - name: Setup web environment uses: ./.github/actions/setup-web - - name: Detect changed files and generate diff - id: detect_changes + - name: Prepare sync context + id: context + shell: bash run: | - if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then - # Manual trigger - if [ -n "${{ github.event.inputs.files }}" ]; then - echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT - else - # Get all JSON files in en-US directory - files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ') - echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT - fi - echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT - echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT + DEFAULT_TARGET_LANGS=$(awk " + /value: '/ { + value=\$2 + gsub(/[',]/, \"\", value) + } + /supported: true/ && value != \"en-US\" { + printf \"%s \", value + } + " web/i18n-config/languages.ts | sed 's/[[:space:]]*$//') - # For manual trigger with incremental mode, get diff from last commit - # For full mode, we'll do a complete check anyway - if [ "${{ github.event.inputs.mode }}" == "full" ]; then - echo "Full mode: will check all keys" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + if [ "${{ github.event_name }}" = "push" ]; then + BASE_SHA="${{ github.event.before }}" + if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then + BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true) + fi + HEAD_SHA="${{ github.sha }}" + if [ -n "$BASE_SHA" ]; then + CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//') else - git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt - if [ -s /tmp/i18n-diff.txt ]; then - echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT - else - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - fi - fi - elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then - # Triggered by push via trigger-i18n-sync.yml workflow - # Validate required payload fields - if [ -z "${{ github.event.client_payload.changed_files }}" ]; then - echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2 - exit 1 - fi - echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT - echo "TARGET_LANGS=" >> $GITHUB_OUTPUT - echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT - - # Decode the base64-encoded diff from the trigger workflow - if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then - if ! 
echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then - echo "Warning: Failed to decode base64 diff payload" >&2 - echo "" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - elif [ -s /tmp/i18n-diff.txt ]; then - echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT - else - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - fi - else - echo "" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//') fi + TARGET_LANGS="$DEFAULT_TARGET_LANGS" + SYNC_MODE="incremental" else - echo "Unsupported event type: ${{ github.event_name }}" - exit 1 + BASE_SHA="" + HEAD_SHA=$(git rev-parse HEAD) + if [ -n "${{ github.event.inputs.languages }}" ]; then + TARGET_LANGS="${{ github.event.inputs.languages }}" + else + TARGET_LANGS="$DEFAULT_TARGET_LANGS" + fi + SYNC_MODE="${{ github.event.inputs.mode || 'incremental' }}" + if [ -n "${{ github.event.inputs.files }}" ]; then + CHANGED_FILES="${{ github.event.inputs.files }}" + elif [ "$SYNC_MODE" = "incremental" ]; then + BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true) + if [ -n "$BASE_SHA" ]; then + CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//') + else + CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//') + fi + elif [ "$SYNC_MODE" = "full" ]; then + echo "workflow_dispatch full mode requires the files input to stay within CI limits." >&2 + exit 1 + else + CHANGED_FILES="" + fi fi - # Truncate diff if too large (keep first 50KB) - if [ -f /tmp/i18n-diff.txt ]; then - head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt - mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt + FILE_ARGS="" + if [ -n "$CHANGED_FILES" ]; then + FILE_ARGS="--file $CHANGED_FILES" fi - echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')" + LANG_ARGS="" + if [ -n "$TARGET_LANGS" ]; then + LANG_ARGS="--lang $TARGET_LANGS" + fi + + { + echo "DEFAULT_TARGET_LANGS=$DEFAULT_TARGET_LANGS" + echo "BASE_SHA=$BASE_SHA" + echo "HEAD_SHA=$HEAD_SHA" + echo "CHANGED_FILES=$CHANGED_FILES" + echo "TARGET_LANGS=$TARGET_LANGS" + echo "SYNC_MODE=$SYNC_MODE" + echo "FILE_ARGS=$FILE_ARGS" + echo "LANG_ARGS=$LANG_ARGS" + } >> "$GITHUB_OUTPUT" + + echo "Files: ${CHANGED_FILES:-}" + echo "Languages: ${TARGET_LANGS:-}" + echo "Mode: $SYNC_MODE" - name: Run Claude Code for Translation Sync - if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77 + if: steps.context.outputs.CHANGED_FILES != '' + uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} - # Allow github-actions bot to trigger this workflow via repository_dispatch - # See: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md allowed_bots: 'github-actions[bot]' + show_full_output: ${{ github.event_name == 'workflow_dispatch' }} prompt: | - You are a professional i18n synchronization engineer for the Dify project. 
- Your task is to keep all language translations in sync with the English source (en-US). + You are the i18n sync agent for the Dify repository. + Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`, then open a PR with the result. - ## CRITICAL TOOL RESTRICTIONS - - Use **Read** tool to read files (NOT cat or bash) - - Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts) - - Use **Bash** ONLY for: git commands, gh commands, pnpm commands - - Run bash commands ONE BY ONE, never combine with && or || - - NEVER use `$()` command substitution - it's not supported. Split into separate commands instead. + Use absolute paths at all times: + - Repo root: `${{ github.workspace }}` + - Web directory: `${{ github.workspace }}/web` + - Language config: `${{ github.workspace }}/web/i18n-config/languages.ts` - ## WORKING DIRECTORY & ABSOLUTE PATHS - Claude Code sandbox working directory may vary. Always use absolute paths: - - For pnpm: `pnpm --dir ${{ github.workspace }}/web ` - - For git: `git -C ${{ github.workspace }} ` - - For gh: `gh --repo ${{ github.repository }} ` - - For file paths: `${{ github.workspace }}/web/i18n/` + Inputs: + - Files in scope: `${{ steps.context.outputs.CHANGED_FILES }}` + - Target languages: `${{ steps.context.outputs.TARGET_LANGS }}` + - Sync mode: `${{ steps.context.outputs.SYNC_MODE }}` + - Base SHA: `${{ steps.context.outputs.BASE_SHA }}` + - Head SHA: `${{ steps.context.outputs.HEAD_SHA }}` + - Scoped file args: `${{ steps.context.outputs.FILE_ARGS }}` + - Scoped language args: `${{ steps.context.outputs.LANG_ARGS }}` - ## EFFICIENCY RULES - - **ONE Edit per language file** - batch all key additions into a single Edit - - Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them - - Translate ALL keys for a language mentally first, then do ONE Edit - - ## Context - - Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }} - - Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }} - - Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }} - - Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json - - Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts - - Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }} - - ## CRITICAL DESIGN: Verify First, Then Sync - - You MUST follow this three-phase approach: - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 1: VERIFY - Analyze and Generate Change Report ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 1.1: Analyze Git Diff (for incremental mode) - Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff. - - Parse the diff to categorize changes: - - Lines with `+` (not `+++`): Added or modified values - - Lines with `-` (not `---`): Removed or old values - - Identify specific keys for each category: - * ADD: Keys that appear only in `+` lines (new keys) - * UPDATE: Keys that appear in both `-` and `+` lines (value changed) - * DELETE: Keys that appear only in `-` lines (removed keys) - - ### Step 1.2: Read Language Configuration - Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`. - Extract all languages with `supported: true`. 
- - ### Step 1.3: Run i18n:check for Each Language - ```bash - pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile - ``` - ```bash - pnpm --dir ${{ github.workspace }}/web run i18n:check - ``` - - This will report: - - Missing keys (need to ADD) - - Extra keys (need to DELETE) - - ### Step 1.4: Generate Change Report - - Create a structured report identifying: - ``` - ╔══════════════════════════════════════════════════════════════╗ - ║ I18N SYNC CHANGE REPORT ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ Files to process: [list] ║ - ║ Languages to sync: [list] ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ ADD (New Keys): ║ - ║ - [filename].[key]: "English value" ║ - ║ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ UPDATE (Modified Keys - MUST re-translate): ║ - ║ - [filename].[key]: "Old value" → "New value" ║ - ║ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ DELETE (Extra Keys): ║ - ║ - [language]/[filename].[key] ║ - ║ ... ║ - ╚══════════════════════════════════════════════════════════════╝ - ``` - - **IMPORTANT**: For UPDATE detection, compare git diff to find keys where - the English value changed. These MUST be re-translated even if target - language already has a translation (it's now stale!). - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 2: SYNC - Execute Changes Based on Report ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 2.1: Process ADD Operations (BATCH per language file) - - **CRITICAL WORKFLOW for efficiency:** - 1. First, translate ALL new keys for ALL languages mentally - 2. Then, for EACH language file, do ONE Edit operation: - - Read the file once - - Insert ALL new keys at the beginning (right after the opening `{`) - - Don't worry about alphabetical order - lint:fix will sort them later - - Example Edit (adding 3 keys to zh-Hans/app.json): - ``` - old_string: '{\n "accessControl"' - new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"' - ``` - - **IMPORTANT**: - - ONE Edit per language file (not one Edit per key!) - - Always use the Edit tool. NEVER use bash scripts, node, or jq. - - ### Step 2.2: Process UPDATE Operations - - **IMPORTANT: Special handling for zh-Hans and ja-JP** - If zh-Hans or ja-JP files were ALSO modified in the same push: - - Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files - - If found, it means someone manually translated them. Apply these rules: - - 1. **Missing keys**: Still ADD them (completeness required) - 2. **Existing translations**: Compare with the NEW English value: - - If translation is **completely wrong** or **unrelated** → Update it - - If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work - - When in doubt, **keep the manual translation** - - Example: - - English changed: "Save" → "Save Changes" - - Manual translation: "保存更改" → Keep it (correct meaning) - - Manual translation: "删除" → Update it (completely wrong) - - For other languages: - Use Edit tool to replace the old value with the new translation. - You can batch multiple updates in one Edit if they are adjacent. 
- - ### Step 2.3: Process DELETE Operations - For extra keys reported by i18n:check: - - Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove` - - Or manually remove from target language JSON files - - ## Translation Guidelines - - - PRESERVE all placeholders exactly as-is: - - `{{variable}}` - Mustache interpolation - - `${variable}` - Template literal - - `content` - HTML tags - - `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values) - - **CRITICAL: Variable names and tag names MUST stay in English - NEVER translate them** - - ✅ CORRECT examples: - - English: "{{count}} items" → Japanese: "{{count}} 個のアイテム" - - English: "{{name}} updated" → Korean: "{{name}} 업데이트됨" - - English: "{{email}}" → Chinese: "{{email}}" - - English: "Marketplace" → Japanese: "マーケットプレイス" - - ❌ WRONG examples (NEVER do this - will break the application): - - "{{count}}" → "{{カウント}}" ❌ (variable name translated to Japanese) - - "{{name}}" → "{{이름}}" ❌ (variable name translated to Korean) - - "{{email}}" → "{{邮箱}}" ❌ (variable name translated to Chinese) - - "" → "<メール>" ❌ (tag name translated) - - "" → "<自定义链接>" ❌ (component name translated) - - - Use appropriate language register (formal/informal) based on existing translations - - Match existing translation style in each language - - Technical terms: check existing conventions per language - - For CJK languages: no spaces between characters unless necessary - - For RTL languages (ar-TN, fa-IR): ensure proper text handling - - ## Output Format Requirements - - Alphabetical key ordering (if original file uses it) - - 2-space indentation - - Trailing newline at end of file - - Valid JSON (use proper escaping for special characters) - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 3.1: Run Lint Fix (IMPORTANT!) - ```bash - pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json' - ``` - This ensures: - - JSON keys are sorted alphabetically (jsonc/sort-keys rule) - - Valid i18n keys (dify-i18n/valid-i18n-keys rule) - - No extra keys (dify-i18n/no-extra-keys rule) - - ### Step 3.2: Run Final i18n Check - ```bash - pnpm --dir ${{ github.workspace }}/web run i18n:check - ``` - - ### Step 3.3: Fix Any Remaining Issues - If check reports issues: - - Go back to PHASE 2 for unresolved items - - Repeat until check passes - - ### Step 3.4: Generate Final Summary - ``` - ╔══════════════════════════════════════════════════════════════╗ - ║ SYNC COMPLETED SUMMARY ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ Language │ Added │ Updated │ Deleted │ Status ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║ - ║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║ - ║ ... │ ... │ ... │ ... │ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ i18n:check │ PASSED - All keys in sync ║ - ╚══════════════════════════════════════════════════════════════╝ - ``` - - ## Mode-Specific Behavior - - **SYNC_MODE = "incremental"** (default): - - Focus on keys identified from git diff - - Also check i18n:check output for any missing/extra keys - - Efficient for small changes - - **SYNC_MODE = "full"**: - - Compare ALL keys between en-US and each language - - Run i18n:check to identify all discrepancies - - Use for first-time sync or fixing historical issues - - ## Important Notes - - 1. 
Always run i18n:check BEFORE and AFTER making changes
-            2. The check script is the source of truth for missing/extra keys
-            3. For UPDATE scenario: git diff is the source of truth for changed values
-            4. Create a single commit with all translation changes
-            5. If any translation fails, continue with others and report failures
-
-            ═══════════════════════════════════════════════════════════════
-            ║ PHASE 4: COMMIT AND CREATE PR                                ║
-            ═══════════════════════════════════════════════════════════════
-
-            After all translations are complete and verified:
-
-            ### Step 4.1: Check for changes
-            ```bash
-            git -C ${{ github.workspace }} status --porcelain
-            ```
-
-            If there are changes:
-
-            ### Step 4.2: Create a new branch and commit
-            Run these git commands ONE BY ONE (not combined with &&).
-            **IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
-
-            1. First, get the timestamp:
-            ```bash
-            date +%Y%m%d-%H%M%S
-            ```
-            (Note the output, e.g., "20260115-143052")
-
-            2. Then create branch using the timestamp value:
-            ```bash
-            git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
-            ```
-            (Replace "20260115-143052" with the actual timestamp from step 1)
-
-            3. Stage changes:
-            ```bash
-            git -C ${{ github.workspace }} add web/i18n/
-            ```
-
-            4. Commit:
-            ```bash
-            git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US
-
-            Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
-            ```
-
-            5. Push:
-            ```bash
-            git -C ${{ github.workspace }} push origin HEAD
-            ```
-
-            ### Step 4.3: Create Pull Request
-            ```bash
-            gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
-
-            This PR was automatically generated to sync i18n translation files.
-
-            ### Changes
-            - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
-            - Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
-
-            ### Verification
-            - [x] \`i18n:check\` passed
-            - [x] \`lint:fix\` applied
-
-            🤖 Generated with Claude Code GitHub Action" --base main
-            ```
+
+            Tool rules:
+            - Use Read for repository files.
+            - Use Edit for JSON updates.
+            - Use Bash only for `git`, `gh`, `pnpm`, and `date`.
+            - Run Bash commands one by one. Do not combine commands with `&&`, `||`, pipes, or command substitution.
+
+            Required execution plan:
+            1. Resolve target languages.
+               - Use the provided `Target languages` value as the source of truth.
+               - If it is unexpectedly empty, read `${{ github.workspace }}/web/i18n-config/languages.ts` and use every language with `supported: true` except `en-US`.
+            2. Stay strictly in scope.
+               - Only process the files listed in `Files in scope`.
+               - Only process the resolved target languages, never `en-US`.
+               - Do not touch unrelated i18n files.
+               - Do not modify `${{ github.workspace }}/web/i18n/en-US/`.
+            3. Detect English changes per file.
+               - Read the current English JSON file for each file in scope.
+               - If sync mode is `incremental` and `Base SHA` is not empty, run:
+                 `git -C ${{ github.workspace }} show <Base SHA>:web/i18n/en-US/<file>.json`
+               - If sync mode is `full` or `Base SHA` is empty, skip historical comparison and treat the current English file as the only source of truth for structural sync.
+               - If the file did not exist at Base SHA, treat all current keys as ADD.
+               - Compare previous and current English JSON to identify:
+                 - ADD: key only in current
+                 - UPDATE: key exists in both and the English value changed
+                 - DELETE: key only in previous
+               - Do not rely on a truncated diff file.
+            4. 
Run a scoped pre-check before editing:
+               - `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
+               - Use this command as the source of truth for missing and extra keys inside the current scope.
+            5. Apply translations.
+               - For every target language and scoped file:
+                 - If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
+                 - ADD missing keys.
+                 - UPDATE stale translations when the English value changed.
+                 - DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
+               - For `zh-Hans` and `ja-JP`, if the locale file also changed between Base SHA and Head SHA, preserve manual translations unless they are clearly wrong for the new English value. If in doubt, keep the manual translation.
+               - Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
+               - Match the existing terminology and register used by each locale.
+               - Prefer one Edit per file when practical, but prioritize correctness over batching.
+            6. Verify only the edited files.
+               - Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <edited i18n file paths>`
+               - Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
+               - If verification fails, fix the remaining problems before continuing.
+            7. Create a PR only when there are changes in `web/i18n/`.
+               - Check `git -C ${{ github.workspace }} status --porcelain -- web/i18n/`
+               - Create branch `chore/i18n-sync-<timestamp>` (timestamp from `date +%Y%m%d-%H%M%S`)
+               - Commit message: `chore(i18n): sync translations with en-US`
+               - Push the branch and open a PR against `main`
+               - PR title: `chore(i18n): sync translations with en-US`
+               - PR body: summarize files, languages, sync mode, and verification commands
+            8. If there are no translation changes after verification, do not create a branch, commit, or PR.
           claude_args: |
-            --max-turns 150
+            --max-turns 80
             --allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"
diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml
deleted file mode 100644
index 1caaddd47a..0000000000
--- a/.github/workflows/trigger-i18n-sync.yml
+++ /dev/null
@@ -1,66 +0,0 @@
-name: Trigger i18n Sync on Push
-
-# This workflow bridges the push event to repository_dispatch
-# because claude-code-action doesn't support push events directly.
-# See: https://github.com/langgenius/dify/issues/30743 - -on: - push: - branches: [main] - paths: - - 'web/i18n/en-US/*.json' - -permissions: - contents: write - -jobs: - trigger: - if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest - timeout-minutes: 5 - - steps: - - name: Checkout repository - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 0 - - - name: Detect changed files and generate diff - id: detect - run: | - BEFORE_SHA="${{ github.event.before }}" - # Handle edge case: force push may have null/zero SHA - if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then - BEFORE_SHA="HEAD~1" - fi - - # Detect changed i18n files - changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "") - echo "changed_files=$changed" >> $GITHUB_OUTPUT - - # Generate diff for context - git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt - - # Truncate if too large (keep first 50KB to match receiving workflow) - head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt - mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt - - # Base64 encode the diff for safe JSON transport (portable, single-line) - diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n') - echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT - - if [ -n "$changed" ]; then - echo "has_changes=true" >> $GITHUB_OUTPUT - echo "Detected changed files: $changed" - else - echo "has_changes=false" >> $GITHUB_OUTPUT - echo "No i18n changes detected" - fi - - - name: Trigger i18n sync workflow - if: steps.detect.outputs.has_changes == 'true' - uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - event-type: i18n-sync - client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}' diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index f45f2137d6..7c4cd0ba8c 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -14,7 +14,6 @@ jobs: strategy: matrix: python-version: - - "3.11" - "3.12" steps: diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml new file mode 100644 index 0000000000..8035d1ef8e --- /dev/null +++ b/.github/workflows/web-e2e.yml @@ -0,0 +1,72 @@ +name: Web Full-Stack E2E + +on: + workflow_call: + +permissions: + contents: read + +concurrency: + group: web-e2e-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + test: + name: Web Full-Stack E2E + runs-on: ubuntu-latest + defaults: + run: + shell: bash + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web dependencies + uses: ./.github/actions/setup-web + + - name: Install E2E package dependencies + working-directory: ./e2e + run: vp install --frozen-lockfile + + - name: Setup UV and Python + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + with: + enable-cache: true + python-version: "3.12" + cache-dependency-glob: api/uv.lock + + - name: Install API dependencies + run: uv sync --project api --dev + + - name: Install 
Playwright browser + working-directory: ./e2e + run: vp run e2e:install + + - name: Run isolated source-api and built-web Cucumber E2E tests + working-directory: ./e2e + env: + E2E_ADMIN_EMAIL: e2e-admin@example.com + E2E_ADMIN_NAME: E2E Admin + E2E_ADMIN_PASSWORD: E2eAdmin12345 + E2E_FORCE_WEB_BUILD: "1" + E2E_INIT_PASSWORD: E2eInit12345 + run: vp run e2e:full + + - name: Upload Cucumber report + if: ${{ !cancelled() }} + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: cucumber-report + path: e2e/cucumber-report + retention-days: 7 + + - name: Upload E2E logs + if: ${{ !cancelled() }} + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: e2e-logs + path: e2e/.logs + retention-days: 7 diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index d40cd4bfeb..8110a16355 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -22,8 +22,8 @@ jobs: strategy: fail-fast: false matrix: - shardIndex: [1, 2, 3, 4, 5, 6] - shardTotal: [6] + shardIndex: [1, 2, 3, 4] + shardTotal: [4] defaults: run: shell: bash @@ -66,7 +66,6 @@ jobs: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: - fetch-depth: 0 persist-credentials: false - name: Setup web environment diff --git a/Makefile b/Makefile index 55871c86a7..c377b7c671 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,12 @@ type-check: @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . @echo "✅ Type checks complete" +type-check-core: + @echo "📝 Running core type checks (basedpyright + mypy)..." + @./dev/basedpyright-check $(PATH_TO_CHECK) + @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + @echo "✅ Core type checks complete" + test: @echo "🧪 Running backend unit tests..." @if [ -n "$(TARGET_TESTS)" ]; then \ @@ -133,6 +139,7 @@ help: @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" @echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)" + @echo " make type-check-core - Run core type checks (basedpyright, mypy)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/README.md b/README.md index bef8f6b782..d9848a6c78 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features (including [Opik](https://www.comet.com/docs/opik/integrations/dify), [Langfuse](https://docs.langfuse.com), and [Arize Phoenix](https://docs.arize.com/phoenix)) and more, letting you quickly go from prototype to production. Here's a list of the core features:
diff --git a/api/.env.example b/api/.env.example
index 9672a99d55..c6541731e6 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -127,7 +127,8 @@ ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
 # Don't start with '/'. OSS doesn't support leading slash in object names.
 ALIYUN_OSS_PATH=your-path
-ALIYUN_CLOUDBOX_ID=your-cloudbox-id
+# Optional CloudBox ID for Aliyun OSS. DO NOT enable it if you are not using CloudBox.
+#ALIYUN_CLOUDBOX_ID=your-cloudbox-id
 
 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
diff --git a/api/.importlinter b/api/.importlinter
index c2841f64d2..5e06947d94 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -3,7 +3,6 @@ root_packages =
     core
     constants
     context
-    graphon
    configs
    controllers
    extensions
@@ -13,152 +12,3 @@ root_packages =
     tasks
     services
 include_external_packages = True
-
-[importlinter:contract:workflow]
-name = Workflow
-type=layers
-layers =
-    graph_engine
-    graph_events
-    graph
-    nodes
-    node_events
-    runtime
-    entities
-containers =
-    graphon
-ignore_imports =
-    graphon.nodes.base.node -> graphon.graph_events
-    graphon.nodes.iteration.iteration_node -> graphon.graph_events
-    graphon.nodes.loop.loop_node -> graphon.graph_events
-
-    graphon.nodes.iteration.iteration_node -> graphon.graph_engine
-    graphon.nodes.loop.loop_node -> graphon.graph_engine
-    # TODO(QuantumGhost): fix the import violation later
-    graphon.entities.pause_reason -> graphon.nodes.human_input.entities
-
-[importlinter:contract:workflow-external-imports]
-name = Workflow External Imports
-type = forbidden
-source_modules =
-    graphon
-forbidden_modules =
-    constants
-    configs
-    context
-    controllers
-    extensions
-    factories
-    libs
-    models
-    services
-    tasks
-    core.agent
-    core.app
-    core.base
-    core.callback_handler
-    core.datasource
-    core.db
-    core.entities
-    core.errors
-    core.extension
-    core.external_data_tool
-    core.file
-    core.helper
-    core.hosting_configuration
-    core.indexing_runner
-    core.llm_generator
-    core.logging
-    core.mcp
-    core.memory
-    core.moderation
-    core.ops
-    core.plugin
-    core.prompt
-    core.provider_manager
-    core.rag
-    core.repositories
-    core.schemas
-    core.tools
-    core.trigger
-    core.variables
-
-[importlinter:contract:workflow-third-party-imports]
-name = Workflow Third-Party Imports
-type = forbidden
-source_modules =
-    graphon
-forbidden_modules =
-    sqlalchemy
-
-[importlinter:contract:rsc]
-name = RSC
-type = layers
-layers =
-    graph_engine
-    response_coordinator
-containers =
-    graphon.graph_engine
-
-[importlinter:contract:worker]
-name = Worker
-type = layers
-layers =
-    graph_engine
-    worker
-containers =
-    graphon.graph_engine
-
-[importlinter:contract:graph-engine-architecture]
-name = Graph Engine Architecture
-type = layers
-layers =
-    graph_engine
-    orchestration
-    command_processing
-    event_management
-    error_handler
-    graph_traversal
-    graph_state_manager
-    worker_management
-    domain
-containers =
-    graphon.graph_engine
-
-[importlinter:contract:domain-isolation]
-name = Domain Model Isolation
-type = forbidden
-source_modules =
-    graphon.graph_engine.domain
-forbidden_modules =
-    graphon.graph_engine.worker_management
-    graphon.graph_engine.command_channels
-    graphon.graph_engine.layers
-    graphon.graph_engine.protocols
-
-[importlinter:contract:worker-management]
-name = Worker Management
-type = forbidden
-source_modules =
-    graphon.graph_engine.worker_management
-forbidden_modules =
-    graphon.graph_engine.orchestration
-    graphon.graph_engine.command_processing
-    graphon.graph_engine.event_management
-
-
-[importlinter:contract:graph-traversal-components]
-name = Graph Traversal Components
-type = layers
-layers =
-    edge_processor
-    skip_propagator
-containers =
-    graphon.graph_engine.graph_traversal
-
-[importlinter:contract:command-channels]
-name = Command Channels Independence
-type = independence
-modules =
-    graphon.graph_engine.command_channels.in_memory_channel
-    graphon.graph_engine.command_channels.redis_channel
diff --git a/api/app_factory.py b/api/app_factory.py
index 066eb2ae2c..76838f9925 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp):
         ext_commands,
         ext_compress,
         ext_database,
+        ext_enterprise_telemetry,
         ext_fastopenapi,
         ext_forward_refs,
         ext_hosting_provider,
@@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp):
         ext_commands,
         ext_fastopenapi,
         ext_otel,
+        ext_enterprise_telemetry,
         ext_request_logging,
         ext_session_factory,
     ]
diff --git a/api/configs/app_config.py b/api/configs/app_config.py
index d3b1cf9d5b..831f0a49e0 100644
--- a/api/configs/app_config.py
+++ b/api/configs/app_config.py
@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
 from libs.file_utils import search_file_upwards
 
 from .deploy import DeploymentConfig
-from .enterprise import EnterpriseFeatureConfig
+from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
 from .extra import ExtraServiceConfig
 from .feature import FeatureConfig
 from .middleware import MiddlewareConfig
@@ -73,6 +73,8 @@ class DifyConfig(
     # Enterprise feature configs
     # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
     EnterpriseFeatureConfig,
+    # Enterprise telemetry configs
+    EnterpriseTelemetryConfig,
 ):
     model_config = SettingsConfigDict(
         # read from dotenv format config file
diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py
index f8447c6979..8a6a921a4e 100644
--- a/api/configs/enterprise/__init__.py
+++ b/api/configs/enterprise/__init__.py
@@ -22,3 +22,52 @@ class EnterpriseFeatureConfig(BaseSettings):
     ENTERPRISE_REQUEST_TIMEOUT: int = Field(
         ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
     )
+
+
+class EnterpriseTelemetryConfig(BaseSettings):
+    """
+    Configuration for enterprise telemetry.
+    """
+
+    ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
+        description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
+        default=False,
+    )
+
+    ENTERPRISE_OTLP_ENDPOINT: str = Field(
+        description="Enterprise OTEL collector endpoint.",
+        default="",
+    )
+
+    ENTERPRISE_OTLP_HEADERS: str = Field(
+        description="Auth headers for OTLP export (key=value,key2=value2).",
+        default="",
+    )
+
+    ENTERPRISE_OTLP_PROTOCOL: str = Field(
+        description="OTLP protocol: 'http' or 'grpc' (default: http).",
+        default="http",
+    )
+
+    ENTERPRISE_OTLP_API_KEY: str = Field(
+        description="Bearer token for enterprise OTLP export authentication.",
+        default="",
+    )
+
+    ENTERPRISE_INCLUDE_CONTENT: bool = Field(
+        description="Include input/output content in traces (privacy toggle).",
+        # The default is False to avoid accidentally logging PII data in traces.
+        default=False,
+    )
+
+    ENTERPRISE_SERVICE_NAME: str = Field(
+        description="Service name for OTEL resource.",
+        default="dify",
+    )
+
+    ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
+        description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
+        default=1.0,
+        ge=0.0,
+        le=1.0,
+    )
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
index 515a6a5125..7348ef62aa 100644
--- a/api/controllers/common/fields.py
+++ b/api/controllers/common/fields.py
@@ -2,9 +2,9 @@ from __future__ import annotations
 
 from typing import Any, TypeAlias
 
+from graphon.file import helpers as file_helpers
 from pydantic import BaseModel, ConfigDict, computed_field
 
-from graphon.file import helpers as file_helpers
 from models.model import IconType
 
 JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 357697ed30..738e77b371 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -5,6 +5,8 @@ from typing import Any, Literal, TypeAlias
 
 from flask import request
 from flask_restx import Resource
+from graphon.enums import WorkflowExecutionStatus
+from graphon.file import helpers as file_helpers
 from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -27,8 +29,6 @@ from core.ops.ops_trace_manager import OpsTraceManager
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.trigger.constants import TRIGGER_NODE_TYPES
 from extensions.ext_database import db
-from graphon.enums import WorkflowExecutionStatus
-from graphon.file import helpers as file_helpers
 from libs.login import current_account_with_tenant, login_required
 from models import App, DatasetPermissionEnum, Workflow
 from models.model import IconType
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 91fbe4a85a..78ddb904e1 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -2,6 +2,7 @@ import logging
 
 from flask import request
 from flask_restx import Resource, fields
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError
 
@@ -22,7 +23,6 @@ from controllers.console.app.error import (
 )
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 from models import App, AppMode
 from services.audio_service import AudioService
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index fe274e4c9a..d83925d173 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -3,6 +3,7 @@ from typing import Any, Literal
 
 from flask import request
 from flask_restx import Resource
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound
 
@@ -26,7 +27,6 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.helper.trace_id_helper import get_external_trace_id
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
 from libs.login import current_user, login_required
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index c720a5e074..7101d5df7b 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 
 from flask_restx import Resource
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 
 from controllers.console import console_ns
@@ -19,7 +20,6 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
 from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
 from core.llm_generator.llm_generator import LLMGenerator
 from extensions.ext_database import db
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs.login import current_account_with_tenant, login_required
 from models import App
 from services.workflow_service import WorkflowService
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index dc752939ae..2afe276742 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -3,6 +3,7 @@ from typing import Literal
 
 from flask import request
 from flask_restx import Resource, fields, marshal_with
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import exists, func, select
 from werkzeug.exceptions import InternalServerError, NotFound
 
@@ -26,7 +27,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from extensions.ext_database import db
 from fields.raws import FilesContainedField
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs.helper import TimestampField, uuid_value
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.login import current_account_with_tenant, login_required
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 2737dd1dfd..1f5a84c0b2 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -5,6 +5,10 @@ from typing import Any
 
 from flask import abort, request
 from flask_restx import Resource, fields, marshal_with
+from graphon.enums import NodeType
+from graphon.file import File
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
@@ -35,10 +39,6 @@ from extensions.ext_redis import redis_client
 from factories import file_factory, variable_factory
 from fields.member_fields import simple_account_fields
 from fields.workflow_fields import workflow_fields, workflow_pagination_fields
-from graphon.enums import NodeType
-from graphon.file.models import File
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.helper import TimestampField, uuid_value
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index 8cf0004b09..f0e26c86a5 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -3,6 +3,7 @@ from datetime import datetime
 
 from dateutil.parser import isoparse
 from flask import request
 from flask_restx import Resource, marshal_with
+from graphon.enums import WorkflowExecutionStatus
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import Session
 
@@ -14,7 +15,6 @@ from fields.workflow_app_log_fields import (
     build_workflow_app_log_pagination_model,
     build_workflow_archived_log_pagination_model,
 )
-from graphon.enums import WorkflowExecutionStatus
 from libs.login import login_required
 from models import App
 from models.model import AppMode
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index 657b072490..4052897e9a 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -5,6 +5,10 @@ from typing import Any, NoReturn, ParamSpec, TypeVar
 
 from flask import Response, request
 from flask_restx import Resource, fields, marshal, marshal_with
+from graphon.file import helpers as file_helpers
+from graphon.variables.segment_group import SegmentGroup
+from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
+from graphon.variables.types import SegmentType
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 
@@ -20,10 +24,6 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
-from graphon.file import helpers as file_helpers
-from graphon.variables.segment_group import SegmentGroup
-from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
-from graphon.variables.types import SegmentType
 from libs.login import current_user, login_required
 from models import App, AppMode
 from models.workflow import WorkflowDraftVariable
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 29fa96c4e6..83e8bedc11 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,8 +1,10 @@
 from datetime import UTC, datetime, timedelta
-from typing import Literal, cast
+from typing import Literal, TypedDict, cast
 
 from flask import request
 from flask_restx import Resource, fields, marshal_with
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import WorkflowExecutionStatus
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
@@ -26,8 +28,6 @@ from fields.workflow_run_fields import (
     workflow_run_node_execution_list_fields,
     workflow_run_pagination_fields,
 )
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.enums import WorkflowExecutionStatus
 from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
 from libs.custom_inputs import time_duration
 from libs.helper import uuid_value
@@ -173,6 +173,23 @@ console_ns.schema_model(
 )
 
 
+class HumanInputPauseTypeResponse(TypedDict):
+    type: Literal["human_input"]
+    form_id: str
+    backstage_input_url: str | None
+
+
+class PausedNodeResponse(TypedDict):
+    node_id: str
+    node_title: str
+    pause_type: HumanInputPauseTypeResponse
+
+
+class WorkflowPauseDetailsResponse(TypedDict):
+    paused_at: str | None
+    paused_nodes: list[PausedNodeResponse]
+
+
 @console_ns.route("/apps//advanced-chat/workflow-runs")
 class AdvancedChatAppWorkflowRunListApi(Resource):
     @console_ns.doc("get_advanced_chat_workflow_runs")
@@ -490,10 +507,11 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
         # Check if workflow is suspended
         is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
         if not is_paused:
-            return {
+            empty_response: WorkflowPauseDetailsResponse = {
                 "paused_at": None,
                 "paused_nodes": [],
-            }, 200
+            }
+            return empty_response, 200
 
         pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
         pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
@@ -503,8 +521,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
 
         # Build response
         paused_at = pause_entity.paused_at if pause_entity else None
-        paused_nodes = []
-        response = {
+        paused_nodes: list[PausedNodeResponse] = []
+        response: WorkflowPauseDetailsResponse = {
             "paused_at": paused_at.isoformat() + "Z" if paused_at else None,
             "paused_nodes": paused_nodes,
         }
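The TypedDicts added above pin down the pause-details payload that ConsoleWorkflowPauseDetailsApi returns. A hypothetical instance for a run paused on a human-input node (all values illustrative, shapes taken from the TypedDicts):

example: WorkflowPauseDetailsResponse = {
    "paused_at": "2025-01-01T12:00:00Z",  # isoformat() plus the "Z" suffix appended by the handler
    "paused_nodes": [
        {
            "node_id": "node-1",
            "node_title": "Approval form",
            "pause_type": {
                "type": "human_input",
                "form_id": "form-1",
                "backstage_input_url": None,  # str | None per HumanInputPauseTypeResponse
            },
        }
    ],
}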
diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py
index 665a80802d..686b865871 100644
--- a/api/controllers/console/auth/oauth_server.py
+++ b/api/controllers/console/auth/oauth_server.py
@@ -4,11 +4,11 @@ from typing import Concatenate, ParamSpec, TypeVar
 
 from flask import jsonify, request
 from flask_restx import Resource
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel
 from werkzeug.exceptions import BadRequest, NotFound
 
 from controllers.console.wraps import account_initialization_required, setup_required
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models import Account
 from models.model import OAuthProviderApp
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 5d704b6224..f23c7eb431 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -2,6 +2,7 @@ from typing import Any, cast
 
 from flask import request
 from flask_restx import Resource, fields, marshal, marshal_with
+from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import func, select
 from werkzeug.exceptions import Forbidden, NotFound
 
@@ -51,7 +52,6 @@ from fields.dataset_fields import (
     weighted_score_fields,
 )
 from fields.document_fields import document_status_fields
-from graphon.model_runtime.entities.model_entities import ModelType
 from libs.login import current_account_with_tenant, login_required
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermission, DatasetPermissionEnum
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index edb738aad8..ab367d8483 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -9,6 +9,8 @@ from uuid import UUID
 
 import sqlalchemy as sa
 from flask import request, send_file
 from flask_restx import Resource, fields, marshal, marshal_with
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import BaseModel, Field
 from sqlalchemy import asc, desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound
 
@@ -37,8 +39,6 @@ from fields.document_fields import (
     document_status_fields,
     document_with_segments_fields,
 )
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 2fd84303d7..c5f4e3a6e2 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -2,6 +2,7 @@ import uuid
 
 from flask import request
 from flask_restx import Resource, marshal
+from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field
 from sqlalchemy import String, cast, func, or_, select
 from sqlalchemy.dialects.postgresql import JSONB
 
@@ -30,7 +31,6 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
-from graphon.model_runtime.entities.model_entities import ModelType
 from libs.helper import escape_like_pattern
 from libs.login import current_account_with_tenant, login_required
 from models.dataset import ChildChunk, DocumentSegment
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 699fa599c8..8fb3699849 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -2,6 +2,7 @@ import logging
 from typing import Any
 
 from flask_restx import marshal
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
@@ -20,7 +21,6 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from fields.hit_testing_fields import hit_testing_record_fields
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs.login import current_user
 from models.account import Account
 from services.dataset_service import DatasetService
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index 946fa599e6..1976a6bc8a 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -2,6 +2,8 @@ from typing import Any
 
 from flask import make_response, redirect, request
 from flask_restx import Resource
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, NotFound
 
@@ -10,8 +12,6 @@ from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from core.plugin.impl.oauth import OAuthHandler
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models.provider_ids import DatasourceProviderID
 from services.datasource_provider_service import DatasourceProviderService
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 977ae93c03..f12cbd3495 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -3,6 +3,7 @@ from typing import Any, NoReturn
 
 from flask import Response, request
 from flask_restx import Resource, marshal, marshal_with
+from graphon.variables.types import SegmentType
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
@@ -26,7 +27,6 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
-from graphon.variables.types import SegmentType
 from libs.login import current_user, login_required
 from models import Account
 from models.dataset import Pipeline
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 9079fbc29a..8efb59a8e9 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -4,6 +4,7 @@ from typing import Any, Literal, cast
 
 from flask import abort, request
 from flask_restx import Resource, marshal_with  # type: ignore
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
 
@@ -39,7 +40,6 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from factories import variable_factory
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs import helper
 from libs.helper import TimestampField, UUIDStrOrEmpty
 from libs.login import current_account_with_tenant, current_user, login_required
@@ -53,6 +53,7 @@ from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
 from services.rag_pipeline.rag_pipeline import RagPipelineService
 from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
 from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTransformService
+from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
 
 logger = logging.getLogger(__name__)
 
@@ -781,7 +782,38 @@ class RagPipelineByIdApi(Resource):
             # Commit the transaction in the controller
             session.commit()
 
-        return workflow
+        return workflow
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @edit_permission_required
+    @get_rag_pipeline
+    def delete(self, pipeline: Pipeline, workflow_id: str):
+        """
+        Delete a published workflow version that is not currently active on the pipeline.
+        """
+        if pipeline.workflow_id == workflow_id:
+            abort(400, description=f"Cannot delete workflow that is currently in use by pipeline '{pipeline.id}'")
+
+        workflow_service = WorkflowService()
+
+        with Session(db.engine) as session:
+            try:
+                workflow_service.delete_workflow(
+                    session=session,
+                    workflow_id=workflow_id,
+                    tenant_id=pipeline.tenant_id,
+                )
+                session.commit()
+            except WorkflowInUseError as e:
+                abort(400, description=str(e))
+            except DraftWorkflowDeletionError as e:
+                abort(400, description=str(e))
+            except ValueError as e:
+                raise NotFound(str(e))
+
+        return None, 204
 
 
 @console_ns.route("/rag/pipelines//workflows/published/processing/parameters")
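A sketch of how a client might exercise the new delete endpoint; the URL shape and IDs below are hypothetical, and only the status-code contract (204 deleted, 400 refused, 404 unknown id) comes from the handler above:

import httpx

pipeline_id = "..."   # hypothetical
workflow_id = "..."   # hypothetical
resp = httpx.delete(
    f"http://localhost:5001/console/api/rag/pipelines/{pipeline_id}/workflows/{workflow_id}"
)
# 204: version deleted; 400: version is currently active on the pipeline or
# deletion was refused; 404: the workflow id is unknown.
print(resp.status_code)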
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index bc78ee6d2d..b1b01b5f51 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -1,6 +1,7 @@
 import logging
 
 from flask import request
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError
 
@@ -19,7 +20,6 @@ from controllers.console.app.error import (
 )
 from controllers.console.explore.wraps import InstalledAppResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index ccdccceaa6..eacd7332fe 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -2,6 +2,7 @@ import logging
 from typing import Any, Literal
 from uuid import UUID
 
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound
 
@@ -25,7 +26,6 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from extensions.ext_database import db
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_user
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index a72cf6328a..fcbefcda33 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -2,6 +2,7 @@ import logging
 from typing import Literal
 
 from flask import request
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, TypeAdapter
 from werkzeug.exceptions import InternalServerError, NotFound
 
@@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from fields.conversation_fields import ResultResponse
 from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import UUIDStrOrEmpty
 from libs.login import current_account_with_tenant
diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py
index 26aa086aac..e432574434 100644
--- a/api/controllers/console/explore/trial.py
+++ b/api/controllers/console/explore/trial.py
@@ -3,6 +3,8 @@ from typing import Any, Literal, cast
 
 from flask import request
 from flask_restx import Resource, fields, marshal, marshal_with
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
@@ -59,8 +61,6 @@ from fields.workflow_fields import (
     workflow_fields,
     workflow_partial_fields,
 )
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
 from libs.login import current_user
diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py
index 17dbbdd534..42cafc7193 100644
--- a/api/controllers/console/explore/workflow.py
+++ b/api/controllers/console/explore/workflow.py
@@ -1,6 +1,8 @@
 import logging
 from typing import Any
 
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel
 from werkzeug.exceptions import InternalServerError
 
@@ -22,8 +24,6 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from extensions.ext_redis import redis_client
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.login import current_account_with_tenant
 from models.model import AppMode, InstalledApp
diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py
index 7207f7fd1d..e37e78c966 100644
--- a/api/controllers/console/human_input_form.py
+++ b/api/controllers/console/human_input_form.py
@@ -15,6 +15,7 @@ from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from controllers.web.error import InvalidArgumentError, NotFoundError
 from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
+from core.app.apps.base_app_generator import BaseAppGenerator
 from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
 from core.app.apps.message_generator import MessageGenerator
 from core.app.apps.workflow.app_generator import WorkflowAppGenerator
@@ -166,6 +167,7 @@ class ConsoleWorkflowEventsApi(Resource):
         else:
             msg_generator = MessageGenerator()
 
+        generator: BaseAppGenerator
         if app.mode == AppMode.ADVANCED_CHAT:
             generator = AdvancedChatAppGenerator()
         elif app.mode == AppMode.WORKFLOW:
@@ -202,7 +204,7 @@ class ConsoleWorkflowEventsApi(Resource):
     )
 
 
-def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
+def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun) -> App:
     query = select(App).where(
         App.id == workflow_run.app_id,
         App.tenant_id == workflow_run.tenant_id,
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py
index 2a46d2250a..551c86fd82 100644
--- a/api/controllers/console/remote_files.py
+++ b/api/controllers/console/remote_files.py
@@ -2,6 +2,7 @@ import urllib.parse
 
 import httpx
 from flask_restx import Resource
+from graphon.file import helpers as file_helpers
 from pydantic import BaseModel, Field
 
 import services
@@ -15,7 +16,6 @@ from controllers.console import console_ns
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
 from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
-from graphon.file import helpers as file_helpers
 from libs.login import current_account_with_tenant, login_required
 from services.file_service import FileService
diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py
index 764f488755..3fdcbc4710 100644
--- a/api/controllers/console/workspace/agent_providers.py
+++ b/api/controllers/console/workspace/agent_providers.py
@@ -1,8 +1,8 @@
 from flask_restx import Resource, fields
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from services.agent_service import AgentService
diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py
index f45b72f390..b6b9deb1f9 100644
--- a/api/controllers/console/workspace/endpoint.py
+++ b/api/controllers/console/workspace/endpoint.py
@@ -2,13 +2,13 @@ from typing import Any
 
 from flask import request
 from flask_restx import Resource
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field
 
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
 from core.plugin.impl.exc import PluginPermissionDeniedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from services.plugin.endpoint_service import EndpointService
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index 2a6f37aec8..e4cfca9fa4 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -1,12 +1,12 @@
 from flask_restx import Resource
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from pydantic import BaseModel
 from werkzeug.exceptions import Forbidden
 
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from libs.login import current_account_with_tenant, login_required
 from models import TenantAccountRole
 from services.model_load_balancing_service import ModelLoadBalancingService
diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index b22b91706e..8e0aefc9e3 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -3,13 +3,13 @@ from typing import Any, Literal
 
 from flask import request, send_file
 from flask_restx import Resource
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator
 
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.helper import uuid_value
 from libs.login import current_account_with_tenant, login_required
 from services.billing_service import BillingService
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 3c7b97d7fc..2ec1a9435a 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -3,14 +3,14 @@ from typing import Any, cast
 
 from flask import request
 from flask_restx import Resource
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator
 
 from controllers.common.schema import register_enum_models, register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.helper import uuid_value
 from libs.login import current_account_with_tenant, login_required
 from services.model_load_balancing_service import ModelLoadBalancingService
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index 6564ff5e7f..aa674a63b3 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -4,6 +4,7 @@ from typing import Any, Literal
 
 from flask import request, send_file
 from flask_restx import Resource
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import Forbidden
 
@@ -14,7 +15,6 @@ from controllers.console import console_ns
 from controllers.console.workspace import plugin_permission_required
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
 from core.plugin.impl.exc import PluginDaemonClientSideError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
 from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@@ -200,7 +200,7 @@ class PluginDebuggingKeyApi(Resource):
                 "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
             }
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/list")
@@ -215,7 +215,7 @@ class PluginListApi(Resource):
         try:
             plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@@ -232,7 +232,7 @@ class PluginListLatestVersionsApi(Resource):
         try:
             versions = PluginService.list_latest_versions(args.plugin_ids)
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder({"versions": versions})
@@ -251,7 +251,7 @@ class PluginListInstallationsFromIdsApi(Resource):
         try:
             plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids)
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder({"plugins": plugins})
@@ -266,7 +266,7 @@ class PluginIconApi(Resource):
         try:
             icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
         return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@@ -286,7 +286,7 @@ class PluginAssetApi(Resource):
             binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
             return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/upload/pkg")
@@ -303,7 +303,7 @@ class PluginUploadFromPkgApi(Resource):
         try:
             response = PluginService.upload_pkg(tenant_id, content)
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder(response)
@@ -323,7 +323,7 @@ class PluginUploadFromGithubApi(Resource):
         try:
             response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package)
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder(response)
@@ -361,7 +361,7 @@ class PluginInstallFromPkgApi(Resource):
         try:
             response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers)
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder(response)
@@ -387,7 +387,7 @@ class PluginInstallFromGithubApi(Resource):
                 args.package,
             )
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder(response)
@@ -407,7 +407,7 @@ class PluginInstallFromMarketplaceApi(Resource):
         try:
             response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers)
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder(response)
@@ -433,7 +433,7 @@ class PluginFetchMarketplacePkgApi(Resource):
                 }
             )
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/fetch-manifest")
@@ -453,7 +453,7 @@ class PluginFetchManifestApi(Resource):
                 {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()}
             )
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/tasks")
@@ -471,7 +471,7 @@ class PluginFetchInstallTasksApi(Resource):
         try:
             return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/tasks/")
@@ -486,7 +486,7 @@ class PluginFetchInstallTaskApi(Resource):
         try:
             return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/tasks//delete")
@@ -501,7 +501,7 @@ class PluginDeleteInstallTaskApi(Resource):
         try:
             return {"success": PluginService.delete_install_task(tenant_id, task_id)}
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/tasks/delete_all")
@@ -516,7 +516,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
         try:
             return {"success": PluginService.delete_all_install_task_items(tenant_id)}
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/tasks//delete/")
@@ -531,7 +531,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
         try:
             return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
@@ -553,7 +553,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
                 )
             )
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/upgrade/github")
@@ -580,7 +580,7 @@ class PluginUpgradeFromGithubApi(Resource):
                 )
             )
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/uninstall")
@@ -598,7 +598,7 @@ class PluginUninstallApi(Resource):
         try:
             return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
 
 @console_ns.route("/workspaces/current/plugin/permission/change")
@@ -674,7 +674,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
                 provider_type=args.provider_type,
             )
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder({"options": options})
@@ -705,7 +705,7 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
                 credentials=args.credentials,
             )
         except PluginDaemonClientSideError as e:
-            raise ValueError(e)
+            return {"code": "plugin_error", "message": e.description}, 400
 
         return jsonable_encoder({"options": options})
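Every plugin endpoint above receives the same substitution: instead of re-raising PluginDaemonClientSideError as a ValueError, the handler now returns a structured 400 payload. Distilled into one place (the helper name is ours for illustration, not part of this change):

def plugin_error_response(e: PluginDaemonClientSideError) -> tuple[dict, int]:
    # Identical to the inline pattern repeated in each endpoint above.
    return {"code": "plugin_error", "message": e.description}, 400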
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 1273b85bc3..02eb0adc94 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -5,6 +5,7 @@ from urllib.parse import urlparse
 
 from flask import make_response, redirect, request, send_file
 from flask_restx import Resource
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
@@ -27,7 +28,6 @@ from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.oauth import OAuthHandler
 from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
 from extensions.ext_database import db
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.helper import alphanumeric, uuid_value
 from libs.login import current_account_with_tenant, login_required
 from models.provider_ids import ToolProviderID
diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py
index feedf074b7..265b6ecd9a 100644
--- a/api/controllers/console/workspace/trigger_providers.py
+++ b/api/controllers/console/workspace/trigger_providers.py
@@ -3,6 +3,7 @@ from typing import Any
 
 from flask import make_response, redirect, request
 from flask_restx import Resource
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, model_validator
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden
 
@@ -15,7 +16,6 @@ from core.plugin.impl.oauth import OAuthHandler
 from core.trigger.entities.entities import SubscriptionBuilderUpdater
 from core.trigger.trigger_manager import TriggerManager
 from extensions.ext_database import db
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_user, login_required
 from models.account import Account
 from models.provider_ids import TriggerProviderID
diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py
index 56730cf37a..3b673d6e1d 100644
--- a/api/controllers/inner_api/app/dsl.py
+++ b/api/controllers/inner_api/app/dsl.py
@@ -8,6 +8,7 @@ Go admin-api caller.
 
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from controllers.common.schema import register_schema_model
@@ -87,7 +88,7 @@ class EnterpriseAppDSLExport(Resource):
         """Export an app's DSL as YAML."""
         include_secret = request.args.get("include_secret", "false").lower() == "true"
 
-        app_model = db.session.query(App).filter_by(id=app_id).first()
+        app_model = db.session.get(App, app_id)
         if not app_model:
             return {"message": "app not found"}, 404
 
@@ -104,7 +105,7 @@ def _get_active_account(email: str) -> Account | None:
 
     Workspace membership is already validated by the Go admin-api caller.
     """
-    account = db.session.query(Account).filter_by(email=email).first()
+    account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
     if account is None or account.status != AccountStatus.ACTIVE:
         return None
     return account
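Both query rewrites above are SQLAlchemy 2.0-style idioms; side by side, with names taken from the diff:

from sqlalchemy import select

# Primary-key lookup: Session.get() can serve the row straight from the
# identity map, unlike query(App).filter_by(id=...).first().
app_model = db.session.get(App, app_id)

# Single-row lookup on a non-key column: scalar() returns the first column
# of the first row, or None when nothing matches.
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))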
a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 31f2797d66..3142e5118e 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,6 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -28,7 +29,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 94afd47f7f..1759075139 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -4,6 +4,9 @@ from typing import Any, Literal from dateutil.parser import isoparse from flask import request from flask_restx import Namespace, Resource, fields +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -30,9 +33,6 @@ from core.helper.trace_id_helper import get_external_trace_id from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index dcf788f7a8..80205b283b 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,6 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -18,7 +19,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 28fa915117..b4cc9874b6 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -2,6 +2,7 @@ from typing import Any from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from 
pydantic import BaseModel, Field from sqlalchemy import select from werkzeug.exceptions import NotFound @@ -21,7 +22,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields -from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 5ac65fc4e6..c0a6cb0a76 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,9 +1,9 @@ from flask_login import current_user from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from graphon.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 8081dee0bd..9ba1dc4a3a 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError @@ -20,7 +21,6 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0528184d79..e37f9af5f0 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging from typing import Any, Literal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 4274b8c9ab..c5505dd60d 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -22,7 +23,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, 
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 4274b8c9ab..c5505dd60d 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -2,6 +2,7 @@ import logging
 from typing import Literal
 
 from flask import request
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, TypeAdapter, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound
 
@@ -22,7 +23,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from fields.conversation_fields import ResultResponse
 from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
 from models.enums import FeedbackRating
diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index fe31e9d4ac..38aeccc642 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -1,6 +1,7 @@
 import urllib.parse
 
 import httpx
+from graphon.file import helpers as file_helpers
 from pydantic import BaseModel, Field, HttpUrl
 
 import services
@@ -13,7 +14,6 @@ from controllers.common.errors import (
 )
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
 from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
-from graphon.file import helpers as file_helpers
 from services.file_service import FileService
 from ..common.schema import register_schema_models
diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py
index ccef6e5b7f..7f5521f9f5 100644
--- a/api/controllers/web/workflow.py
+++ b/api/controllers/web/workflow.py
@@ -1,6 +1,8 @@
 import logging
 from typing import Any
 
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError
 
@@ -23,8 +25,6 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from extensions.ext_redis import redis_client
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index a846cf4b0f..06c746990d 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -4,7 +4,21 @@ import uuid
 from decimal import Decimal
 from typing import Union, cast
 
-from sqlalchemy import select
+from graphon.file import file_manager
+from graphon.model_runtime.entities import (
+    AssistantPromptMessage,
+    LLMUsage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
+from graphon.model_runtime.entities.model_entities import ModelFeature
+from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from sqlalchemy import func, select
 
 from core.agent.entities import AgentEntity, AgentToolEntity
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -29,20 +43,6 @@ from core.tools.tool_manager import ToolManager
 from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.file import file_manager
-from graphon.model_runtime.entities import (
-    AssistantPromptMessage,
-    LLMUsage,
-    PromptMessage,
-    PromptMessageTool,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
-    UserPromptMessage,
-)
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
-from graphon.model_runtime.entities.model_entities import ModelFeature
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from models.enums import CreatorUserRole
 from models.model import Conversation, Message, MessageAgentThought, MessageFile
@@ -104,11 +104,14 @@ class BaseAgentRunner(AppRunner):
         )
 
         # get how many agent thoughts have been created
         self.agent_thought_count = (
-            db.session.query(MessageAgentThought)
-            .where(
-                MessageAgentThought.message_id == self.message.id,
+            db.session.scalar(
+                select(func.count())
+                .select_from(MessageAgentThought)
+                .where(
+                    MessageAgentThought.message_id == self.message.id,
+                )
             )
-            .count()
+            or 0
         )
         db.session.close()
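The agent-thought counter above is the one behavioral change in this block: the legacy Query.count() call becomes an explicit select(func.count()) executed through Session.scalar(), with "or 0" guarding the None case. A self-contained sketch of the same pattern, using a toy model rather than Dify's MessageAgentThought:

    from sqlalchemy import String, func, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Thought(Base):  # hypothetical stand-in for MessageAgentThought
        __tablename__ = "thoughts"
        id: Mapped[int] = mapped_column(primary_key=True)
        message_id: Mapped[str] = mapped_column(String(36))

    def count_thoughts(session: Session, message_id: str) -> int:
        # Emits: SELECT count(*) FROM thoughts WHERE message_id = :message_id
        stmt = select(func.count()).select_from(Thought).where(Thought.message_id == message_id)
        return session.scalar(stmt) or 0  # "or 0" mirrors the hunk above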
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 0a0fdfdd29..11e2aa062d 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -4,6 +4,15 @@ from abc import ABC, abstractmethod
 from collections.abc import Generator, Mapping, Sequence
 from typing import Any
 
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+
 from core.agent.base_agent_runner import BaseAgentRunner
 from core.agent.entities import AgentScratchpadUnit
 from core.agent.errors import AgentMaxIterationError
@@ -15,14 +24,6 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessage,
-    PromptMessageTool,
-    ToolPromptMessage,
-    UserPromptMessage,
-)
 from models.model import Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index b3fc8d42e6..a4c438e929 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -1,6 +1,5 @@
 import json
 
-from core.agent.cot_agent_runner import CotAgentRunner
 from graphon.file import file_manager
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
@@ -12,6 +11,8 @@ from graphon.model_runtime.entities import (
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from graphon.model_runtime.utils.encoders import jsonable_encoder
 
+from core.agent.cot_agent_runner import CotAgentRunner
+
 
 class CotChatAgentRunner(CotAgentRunner):
     def _organize_system_prompt(self) -> SystemPromptMessage:
diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py
index 51a30998ae..d4c52a8eb1 100644
--- a/api/core/agent/cot_completion_agent_runner.py
+++ b/api/core/agent/cot_completion_agent_runner.py
@@ -1,6 +1,5 @@
 import json
 
-from core.agent.cot_agent_runner import CotAgentRunner
 from graphon.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -9,6 +8,8 @@ from graphon.model_runtime.entities.message_entities import (
 )
 from graphon.model_runtime.utils.encoders import jsonable_encoder
 
+from core.agent.cot_agent_runner import CotAgentRunner
+
 
 class CotCompletionAgentRunner(CotAgentRunner):
     def _organize_instruction_prompt(self) -> str:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index d38d24d1e7..fdffde85d0 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -4,13 +4,6 @@ from collections.abc import Generator
 from copy import deepcopy
 from typing import Any, Union
 
-from core.agent.base_agent_runner import BaseAgentRunner
-from core.agent.errors import AgentMaxIterationError
-from core.app.apps.base_app_queue_manager import PublishFrom
-from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
-from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
-from core.tools.entities.tool_entities import ToolInvokeMeta
-from core.tools.tool_engine import ToolEngine
 from graphon.file import file_manager
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
@@ -26,6 +19,14 @@ from graphon.model_runtime.entities import (
 )
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
+
+from core.agent.base_agent_runner import BaseAgentRunner
+from core.agent.errors import AgentMaxIterationError
+from core.app.apps.base_app_queue_manager import PublishFrom
+from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from core.tools.entities.tool_entities import ToolInvokeMeta
+from core.tools.tool_engine import ToolEngine
 from models.model import Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py
index c3e56fe011..46c1f1230d 100644
--- a/api/core/agent/output_parser/cot_output_parser.py
+++ b/api/core/agent/output_parser/cot_output_parser.py
@@ -3,9 +3,10 @@ import re
 from collections.abc import Generator
 from typing import Union
 
-from core.agent.entities import AgentScratchpadUnit
 from graphon.model_runtime.entities.llm_entities import LLMResultChunk
 
+from core.agent.entities import AgentScratchpadUnit
+
 
 class CotAgentOutputParser:
     @classmethod
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index dbd7527fc6..b7dd55632e 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -1,13 +1,14 @@
 from typing import cast
 
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
 from core.app.app_config.entities import EasyUIBasedAppConfig
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
 
 class ModelConfigConverter:
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
index f279f769aa..5cc385c378 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
@@ -1,9 +1,10 @@
 from collections.abc import Mapping
 from typing import Any
 
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+
 from core.app.app_config.entities import ModelConfigEntity
 from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from models.model import AppModelConfigDict
 from models.provider_ids import ModelProviderID
diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
index 7715a5330a..76196e7034 100644
--- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
@@ -1,5 +1,7 @@
 from typing import Any
 
+from graphon.model_runtime.entities.message_entities import PromptMessageRole
+
 from core.app.app_config.entities import (
     AdvancedChatMessageEntity,
     AdvancedChatPromptTemplateEntity,
@@ -7,7 +9,6 @@ from core.app.app_config.entities import (
     PromptTemplateEntity,
 )
 from core.prompt.simple_prompt_transform import ModelMode
-from graphon.model_runtime.entities.message_entities import PromptMessageRole
 from models.model import AppMode, AppModelConfigDict
diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
index 6d63ae04d3..f0b71c5801 100644
--- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
@@ -1,9 +1,10 @@
 import re
 from typing import cast
 
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
+
 from core.app.app_config.entities import ExternalDataVariableEntity
 from core.external_data_tool.factory import ExternalDataToolFactory
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from models.model import AppModelConfigDict
 
 _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index c67412cc29..536617edba 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -2,13 +2,13 @@ from collections.abc import Sequence
 from enum import StrEnum, auto
 from typing import Any, Literal
 
-from pydantic import BaseModel, Field
-
-from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
 from graphon.file import FileUploadConfig
 from graphon.model_runtime.entities.llm_entities import LLMMode
 from graphon.model_runtime.entities.message_entities import PromptMessageRole
 from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity
+from pydantic import BaseModel, Field
+
+from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
 from models.model import AppMode
diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py
index 9092c1a17d..e96517c426 100644
--- a/api/core/app/app_config/features/file_upload/manager.py
+++ b/api/core/app/app_config/features/file_upload/manager.py
@@ -1,9 +1,10 @@
 from collections.abc import Mapping
 from typing import Any
 
-from constants import DEFAULT_FILE_NUMBER_LIMITS
 from graphon.file import FileUploadConfig
 
+from constants import DEFAULT_FILE_NUMBER_LIMITS
+
 
 class FileUploadConfigManager:
     @classmethod
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
index 13ace32fd6..62e0c31d1a 100644
--- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
+++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
@@ -1,7 +1,8 @@
 import re
 
-from core.app.app_config.entities import RagPipelineVariableEntity
 from graphon.variables.input_entities import VariableEntity
+
+from core.app.app_config.entities import RagPipelineVariableEntity
 from models.workflow import Workflow
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 853cbb426c..aa2b65766f 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -5,7 +5,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
+from typing import TYPE_CHECKING, Any, Literal, Union, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -18,11 +18,21 @@ from constants import UUID_NIL
 if TYPE_CHECKING:
     from controllers.console.app.workflow import LoopNodeRunPayload
 
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
+from graphon.runtime import GraphRuntimeState
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
+
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
 from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
-from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
+from core.app.apps.advanced_chat.generate_task_pipeline import (
+    AdvancedChatAppGenerateTaskPipeline,
+    ConversationSnapshot,
+    MessageSnapshot,
+    WorkflowSnapshot,
+)
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
 from core.app.apps.exc import GenerateTaskStoppedError
@@ -38,13 +48,8 @@ from core.repositories import DifyCoreRepositoryFactory
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from graphon.runtime import GraphRuntimeState
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
-from models.base import Base
 from models.enums import WorkflowRunTriggeredFrom
 from services.conversation_service import ConversationService
 from services.workflow_draft_variable_service import (
@@ -524,19 +529,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         worker_thread.start()
 
-        # release database connection, because the following new thread operations may take a long time
-        with Session(bind=db.engine, expire_on_commit=False) as session:
-            workflow = _refresh_model(session, workflow)
-            message = _refresh_model(session, message)
+        # Capture the scalar fields needed by the response pipeline before
+        # releasing the request-scoped SQLAlchemy session.
+        workflow_snapshot = WorkflowSnapshot.from_workflow(workflow)
+        conversation_snapshot = ConversationSnapshot.from_conversation(conversation)
+        message_snapshot = MessageSnapshot.from_message(message)
         db.session.close()
 
         # return response or stream generator
         response = self._handle_advanced_chat_response(
             application_generate_entity=application_generate_entity,
-            workflow=workflow,
+            workflow=workflow_snapshot,
             queue_manager=queue_manager,
-            conversation=conversation,
-            message=message,
+            conversation=conversation_snapshot,
+            message=message_snapshot,
             user=user,
             stream=stream,
             draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
@@ -643,10 +649,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         self,
         *,
         application_generate_entity: AdvancedChatAppGenerateEntity,
-        workflow: Workflow,
+        workflow: WorkflowSnapshot,
         queue_manager: AppQueueManager,
-        conversation: Conversation,
-        message: Message,
+        conversation: ConversationSnapshot,
+        message: MessageSnapshot,
         user: Union[Account, EndUser],
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
@@ -683,13 +689,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         else:
             logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
             raise e
-
-
-_T = TypeVar("_T", bound=Base)
-
-
-def _refresh_model(session, model: _T) -> _T:
-    with Session(bind=db.engine, expire_on_commit=False) as session:
-        detach_model = session.get(type(model), model.id)
-        assert detach_model is not None
-        return detach_model
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index d21fce144e..a884a1c7f9 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -3,6 +3,12 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
+from graphon.enums import WorkflowType
+from graphon.graph_engine.command_channels import RedisChannel
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import VariableLoader
+from graphon.variables.variables import Variable
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -37,12 +43,6 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.otel import WorkflowAppRunnerHandler, trace_span
-from graphon.enums import WorkflowType
-from graphon.graph_engine.command_channels.redis_channel import RedisChannel
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import VariableLoader
-from graphon.variables.variables import Variable
 from models import Workflow
 from models.model import App, Conversation, Message, MessageAnnotation
 from models.workflow import ConversationVariable
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 51febed32a..5203de225c 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -4,9 +4,17 @@ import re
 import time
 from collections.abc import Callable, Generator, Mapping
 from contextlib import contextmanager
+from dataclasses import dataclass
+from datetime import datetime
 from threading import Thread
 from typing import Any, Union
 
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import WorkflowExecutionStatus
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.model_runtime.utils.encoders import jsonable_encoder
+from graphon.nodes import BuiltinNodeTypes
+from graphon.runtime import GraphRuntimeState
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -69,21 +77,63 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp
 from core.workflow.file_reference import resolve_file_record_id
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_database import db
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.enums import WorkflowExecutionStatus
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.nodes import BuiltinNodeTypes
-from graphon.runtime import GraphRuntimeState
 from libs.datetime_utils import naive_utc_now
 from models import Account, Conversation, EndUser, Message, MessageFile
 from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
 from models.execution_extra_content import HumanInputContent
+from models.model import AppMode
 from models.workflow import Workflow
 
 logger = logging.getLogger(__name__)
 
 
+@dataclass(frozen=True, slots=True)
+class WorkflowSnapshot:
+    id: str
+    tenant_id: str
+    features_dict: Mapping[str, Any]
+
+    @classmethod
+    def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot":
+        return cls(
+            id=workflow.id,
+            tenant_id=workflow.tenant_id,
+            features_dict=dict(workflow.features_dict),
+        )
+
+
+@dataclass(frozen=True, slots=True)
+class ConversationSnapshot:
+    id: str
+    mode: AppMode
+
+    @classmethod
+    def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot":
+        return cls(
+            id=conversation.id,
+            mode=conversation.mode,
+        )
+
+
+@dataclass(frozen=True, slots=True)
+class MessageSnapshot:
+    id: str
+    query: str
+    created_at: datetime
+    status: MessageStatus
+    answer: str
+
+    @classmethod
+    def from_message(cls, message: Message) -> "MessageSnapshot":
+        return cls(
+            id=message.id,
+            query=message.query,
+            created_at=message.created_at,
+            status=message.status,
+            answer=message.answer,
+        )
+
+
 class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
     """
     AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
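The three snapshot dataclasses above exist so the response pipeline can keep reading workflow, conversation, and message fields after db.session.close(). Once the session closes, the ORM instances are detached, and touching an attribute that was not already loaded can raise DetachedInstanceError; a frozen, slotted dataclass of plain values has no such trap and is cheap to hand to the streaming thread. The ordering is what matters, as this short sketch using the names defined above shows:

    # Capture plain values while the session is still live, then release it.
    message_snapshot = MessageSnapshot.from_message(message)  # attribute access loads columns now
    db.session.close()  # the ORM Message instance is now detached

    print(message_snapshot.answer)  # fine: a plain str stored on the dataclass
    # print(message.answer)         # may raise DetachedInstanceError if the attribute was expired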
@@ -92,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
     def __init__(
         self,
         application_generate_entity: AdvancedChatAppGenerateEntity,
-        workflow: Workflow,
+        workflow: WorkflowSnapshot,
         queue_manager: AppQueueManager,
-        conversation: Conversation,
-        message: Message,
+        conversation: ConversationSnapshot,
+        message: MessageSnapshot,
         user: Union[Account, EndUser],
         stream: bool,
         dialogue_count: int,
@@ -156,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         self._message_saved_on_pause = False
         self._seed_graph_runtime_state_from_queue_manager()
 
-    def _seed_task_state_from_message(self, message: Message) -> None:
+    def _seed_task_state_from_message(self, message: MessageSnapshot) -> None:
         if message.status == MessageStatus.PAUSED and message.answer:
             self._task_state.answer = message.answer
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index 1a44cc235e..bb258af4c1 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, overload
 
 from flask import Flask, current_app
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 
 from configs import dify_config
@@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, In
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, App, EndUser
 from services.conversation_service import ConversationService
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 09ddce327e..a20d3f3c38 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -1,6 +1,9 @@
 import logging
 from typing import cast
 
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
+from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from sqlalchemy import select
 
 from core.agent.cot_chat_agent_runner import CotChatAgentRunner
@@ -16,9 +19,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationError
 from extensions.ext_database import db
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from models.model import App, Conversation, Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py
index 5c9ba4567a..66390116d4 100644
--- a/api/core/app/apps/base_app_generate_response_converter.py
+++ b/api/core/app/apps/base_app_generate_response_converter.py
@@ -3,10 +3,11 @@ from abc import ABC, abstractmethod
 from collections.abc import Generator, Mapping
 from typing import Any, Union
 
+from graphon.model_runtime.errors.invoke import InvokeError
+
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 8e8ccf2b90..7eccd59d17 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -2,6 +2,9 @@ from collections.abc import Generator, Mapping, Sequence
 from contextlib import AbstractContextManager, nullcontext
 from typing import TYPE_CHECKING, Any, Union, final
 
+from graphon.enums import NodeType
+from graphon.file import File, FileUploadConfig
+from graphon.variables.input_entities import VariableEntityType
 from sqlalchemy.orm import Session
 
 from core.app.apps.draft_variable_saver import (
@@ -13,9 +16,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.enums import NodeType
-from graphon.file import File, FileUploadConfig
-from graphon.variables.input_entities import VariableEntityType
 from libs.orjson import orjson_dumps
 from models import Account, EndUser
 from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index d1771452c5..20bf81aeec 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -7,6 +7,7 @@ from enum import IntEnum, auto
 from typing import Any
 
 from cachetools import TTLCache, cachedmethod
+from graphon.runtime import GraphRuntimeState
 from redis.exceptions import RedisError
 from sqlalchemy.orm import DeclarativeMeta
 
@@ -21,7 +22,6 @@ from core.app.entities.queue_entities import (
     WorkflowQueueMessage,
 )
 from extensions.ext_redis import redis_client
-from graphon.runtime import GraphRuntimeState
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index 4a4c8b535d..4aebc0cb30 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -5,6 +5,17 @@ from collections.abc import Generator, Mapping, Sequence
 from mimetypes import guess_extension
 from typing import TYPE_CHECKING, Any, Union
 
+from graphon.file import FileTransferMethod, FileType
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    TextPromptMessageContent,
+)
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey
+from graphon.model_runtime.errors.invoke import InvokeBadRequestError
+
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
@@ -30,21 +41,11 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
 from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from core.tools.tool_file_manager import ToolFileManager
 from extensions.ext_database import db
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    ImagePromptMessageContent,
-    PromptMessage,
-    TextPromptMessageContent,
-)
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey
-from graphon.model_runtime.errors.invoke import InvokeBadRequestError
 from models.enums import CreatorUserRole, MessageFileBelongsTo
 from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
 
 if TYPE_CHECKING:
-    from graphon.file.models import File
+    from graphon.file import File
 
 _logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index db3a98c7ac..b675a87382 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, overload
 
 from flask import Flask, copy_current_request_context, current_app
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 
 from configs import dify_config
@@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeF
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from models import Account
 from models.model import App, EndUser
 from services.conversation_service import ConversationService
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index 077c5239f3..050f763e95 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -1,6 +1,8 @@
 import logging
 from typing import cast
 
+from graphon.file import File
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from sqlalchemy import select
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -16,8 +18,6 @@ from core.model_manager import ModelInstance
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
-from graphon.file import File
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from models.model import App, Conversation, Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py
index 2a90fbdad0..ab277857fe 100644
--- a/api/core/app/apps/common/graph_runtime_state_support.py
+++ b/api/core/app/apps/common/graph_runtime_state_support.py
@@ -4,9 +4,10 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from core.workflow.system_variables import SystemVariableKey, get_system_text
 from graphon.runtime import GraphRuntimeState
 
+from core.workflow.system_variables import SystemVariableKey, get_system_text
+
 if TYPE_CHECKING:
     from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
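base_app_runner.py above also trims its TYPE_CHECKING import from graphon.file.models down to the graphon.file package root. The idiom recurs in later files and is worth spelling out: the guarded import is seen only by type checkers, so it costs nothing at runtime and cannot contribute to an import cycle. A generic illustration (the toy function is hypothetical; only the File import is taken from the diff):

    from __future__ import annotations  # annotations become lazily-evaluated strings

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to type checkers only; never executed at runtime.
        from graphon.file import File

    def describe(file: File) -> str:
        # The File annotation is not evaluated at runtime, so the guard is safe.
        return f"<file {file!r}>"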
diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py
index e4aa2ff650..a515531616 100644
--- a/api/core/app/apps/common/workflow_response_converter.py
+++ b/api/core/app/apps/common/workflow_response_converter.py
@@ -6,6 +6,19 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, NewType, TypedDict, Union
 
+from graphon.entities import WorkflowStartReason
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import (
+    BuiltinNodeTypes,
+    WorkflowExecutionStatus,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
+from graphon.file import FILE_MODEL_IDENTITY, File
+from graphon.runtime import GraphRuntimeState
+from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
+from graphon.variables.variables import Variable
+from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -55,19 +68,6 @@ from core.workflow.human_input_forms import load_form_tokens_by_form_id
 from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import (
-    BuiltinNodeTypes,
-    WorkflowExecutionStatus,
-    WorkflowNodeExecutionMetadataKey,
-    WorkflowNodeExecutionStatus,
-)
-from graphon.file import FILE_MODEL_IDENTITY, File
-from graphon.runtime import GraphRuntimeState
-from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
-from graphon.variables.variables import Variable
-from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.datetime_utils import naive_utc_now
 from models import Account, EndUser
 from models.human_input import HumanInputForm
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index c418fe9759..a62c5b80b5 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, overload
 
 from flask import Flask, copy_current_request_context, current_app
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 from sqlalchemy import select
 
@@ -23,7 +24,6 @@ from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, I
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from models import Account, App, EndUser, Message
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.message import MessageNotExistsError
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 6bb1ecdcb1..b216f7cf7b 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -1,6 +1,8 @@
 import logging
 from typing import cast
 
+from graphon.file import File
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from sqlalchemy import select
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -14,8 +16,6 @@ from core.model_manager import ModelInstance
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
-from graphon.file import File
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from models.model import App, Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 48457b5326..fa242003a2 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -10,6 +10,8 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, cast, overload
 
 from flask import Flask, current_app
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from pydantic import ValidationError
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
@@ -41,8 +43,6 @@ from core.repositories.factory import (
     WorkflowNodeExecutionRepository,
 )
 from extensions.ext_database import db
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 44d2450f74..4c188dac68 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -2,6 +2,14 @@ import logging
 import time
 from typing import cast
 
+from graphon.entities import GraphInitParams
+from graphon.enums import WorkflowType
+from graphon.graph import Graph
+from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import VariableLoader
+from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
+
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -18,13 +26,6 @@ from core.workflow.system_variables import build_bootstrap_variables, build_syst
 from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
-from graphon.entities.graph_init_params import GraphInitParams
-from graphon.enums import WorkflowType
-from graphon.graph import Graph
-from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import VariableLoader
-from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
 from models.dataset import Document, Pipeline
 from models.model import EndUser
 from models.workflow import Workflow
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 8ad6893a15..9618ab35c6 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -8,6 +8,10 @@ from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Literal, Union, overload
 
 from flask import Flask, current_app
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
+from graphon.runtime import GraphRuntimeState
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from pydantic import ValidationError
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
@@ -34,10 +38,6 @@ from core.repositories import DifyCoreRepositoryFactory
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from factories import file_factory
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from graphon.runtime import GraphRuntimeState
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from libs.flask_utils import preserve_flask_contexts
 from models.account import Account
 from models.enums import WorkflowRunTriggeredFrom
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index c02c0b16e9..2cb8088971 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -3,6 +3,12 @@ import time
 from collections.abc import Sequence
 from typing import cast
 
+from graphon.enums import WorkflowType
+from graphon.graph_engine.command_channels import RedisChannel
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import VariableLoader
+
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -15,11 +21,6 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_redis import redis_client
 from extensions.otel import WorkflowAppRunnerHandler, trace_span
-from graphon.enums import WorkflowType
-from graphon.graph_engine.command_channels.redis_channel import RedisChannel
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import VariableLoader
 from libs.datetime_utils import naive_utc_now
 from models.workflow import Workflow
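Another recurring change in these runner files: deep module paths such as graphon.graph_engine.layers.base and graphon.graph_engine.command_channels.redis_channel give way to package-level imports. That implies the packages now re-export their public names from __init__.py, roughly as in this hypothetical sketch (the real graphon layout is not shown in this diff):

    # graphon/graph_engine/layers/__init__.py (hypothetical)
    from .base import GraphEngineLayer

    __all__ = ["GraphEngineLayer"]

Callers then depend only on the package's public surface, and the internal modules stay free to move.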
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index e0c5b44ee4..49af169e88 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -4,6 +4,9 @@ from collections.abc import Callable, Generator
 from contextlib import contextmanager
 from typing import Union
 
+from graphon.entities import WorkflowStartReason
+from graphon.enums import WorkflowExecutionStatus
+from graphon.runtime import GraphRuntimeState
 from sqlalchemy.orm import Session
 
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -58,9 +61,6 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_database import db
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import WorkflowExecutionStatus
-from graphon.runtime import GraphRuntimeState
 from models import Account
 from models.enums import CreatorUserRole
 from models.model import EndUser
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index d7d3bd27de..f68c8e60b4 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -3,6 +3,40 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
+from graphon.entities import GraphInitParams
+from graphon.entities.graph_config import NodeConfigDictAdapter
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.graph import Graph
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import (
+    GraphEngineEvent,
+    GraphRunAbortedEvent,
+    GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
+    GraphRunPausedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    NodeRunAgentLogEvent,
+    NodeRunExceptionEvent,
+    NodeRunFailedEvent,
+    NodeRunHumanInputFormFilledEvent,
+    NodeRunHumanInputFormTimeoutEvent,
+    NodeRunIterationFailedEvent,
+    NodeRunIterationNextEvent,
+    NodeRunIterationStartedEvent,
+    NodeRunIterationSucceededEvent,
+    NodeRunLoopFailedEvent,
+    NodeRunLoopNextEvent,
+    NodeRunLoopStartedEvent,
+    NodeRunLoopSucceededEvent,
+    NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
 from pydantic import ValidationError
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -44,40 +78,6 @@ from core.workflow.system_variables import (
 from core.workflow.variable_pool_initializer import add_variables_to_pool
 from core.workflow.workflow_entry import WorkflowEntry
 from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run
-from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDictAdapter
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.graph import Graph
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events import (
-    GraphEngineEvent,
-    GraphRunFailedEvent,
-    GraphRunPartialSucceededEvent,
-    GraphRunPausedEvent,
-    GraphRunStartedEvent,
-    GraphRunSucceededEvent,
-    NodeRunAgentLogEvent,
-    NodeRunExceptionEvent,
-    NodeRunFailedEvent,
-    NodeRunHumanInputFormFilledEvent,
-    NodeRunHumanInputFormTimeoutEvent,
-    NodeRunIterationFailedEvent,
-    NodeRunIterationNextEvent,
-    NodeRunIterationStartedEvent,
-    NodeRunIterationSucceededEvent,
-    NodeRunLoopFailedEvent,
-    NodeRunLoopNextEvent,
-    NodeRunLoopStartedEvent,
-    NodeRunLoopSucceededEvent,
-    NodeRunRetrieverResourceEvent,
-    NodeRunRetryEvent,
-    NodeRunStartedEvent,
-    NodeRunStreamChunkEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.graph_events.graph import GraphRunAbortedEvent
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
 from models.workflow import Workflow
 from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index d8d851c505..0cdbb5f50a 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -2,13 +2,13 @@ from collections.abc import Mapping, Sequence
 from enum import StrEnum
 from typing import TYPE_CHECKING, Any, Optional
 
+from graphon.file import File, FileUploadConfig
+from graphon.model_runtime.entities.model_entities import AIModelEntity
 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
 
 from constants import UUID_NIL
 from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
 from core.entities.provider_configuration import ProviderModelBundle
-from graphon.file import File, FileUploadConfig
-from graphon.model_runtime.entities.model_entities import AIModelEntity
 
 if TYPE_CHECKING:
     from core.ops.ops_trace_manager import TraceQueueManager
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 63857bfff2..5e56341f89 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -3,14 +3,14 @@ from datetime import datetime
 from enum import StrEnum, auto
 from typing import Any
 
+from graphon.entities import WorkflowStartReason
+from graphon.entities.pause_reason import PauseReason
+from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from pydantic import BaseModel, ConfigDict, Field
 
 from core.app.entities.agent_strategy import AgentStrategyInfo
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from graphon.entities.pause_reason import PauseReason
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 
 
 class QueueEvent(StrEnum):
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 719027bd23..ba3b2e356f 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -2,14 +2,14 @@ from collections.abc import Mapping, Sequence
 from enum import StrEnum
 from typing import Any
 
+from graphon.entities import WorkflowStartReason
+from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from graphon.nodes.human_input.entities import FormInput, UserAction
 from pydantic import BaseModel, ConfigDict, Field
 
 from core.app.entities.agent_strategy import AgentStrategyInfo
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from graphon.nodes.human_input.entities import FormInput, UserAction
 
 
 class AnnotationReplyAccount(BaseModel):
diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py
index d59f5125e3..d2d2fea4fb 100644
--- a/api/core/app/features/hosting_moderation/hosting_moderation.py
+++ b/api/core/app/features/hosting_moderation/hosting_moderation.py
@@ -1,8 +1,9 @@
 import logging
 
+from graphon.model_runtime.entities.message_entities import PromptMessage
+
 from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
 from core.helper import moderation
-from graphon.model_runtime.entities.message_entities import PromptMessage
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py
index eeb9abbbfa..e09869f5f8 100644
--- a/api/core/app/layers/conversation_variable_persist_layer.py
+++ b/api/core/app/layers/conversation_variable_persist_layer.py
@@ -9,10 +9,11 @@ scope updates that matter to chat applications.
 
 import logging
 
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
+
 from core.workflow.system_variables import SystemVariableKey, get_system_text
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
 from services.conversation_variable_updater import ConversationVariableUpdater
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py
index 98e2257b1f..79a5442130 100644
--- a/api/core/app/layers/pause_state_persist_layer.py
+++ b/api/core/app/layers/pause_state_persist_layer.py
@@ -1,15 +1,14 @@
 from dataclasses import dataclass
 from typing import Annotated, Literal, Self, TypeAlias
 
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
 from pydantic import BaseModel, Field
 from sqlalchemy import Engine
 from sqlalchemy.orm import Session, sessionmaker
 
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
 from core.workflow.system_variables import SystemVariableKey, get_system_text
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events.base import GraphEngineEvent
-from graphon.graph_events.graph import GraphRunPausedEvent
 from models.model import AppMode
 from repositories.api_workflow_run_repository import APIWorkflowRunRepository
 from repositories.factory import DifyAPIRepositoryFactory
diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py
index 172306f271..1a79a9f843 100644
--- a/api/core/app/layers/suspend_layer.py
+++ b/api/core/app/layers/suspend_layer.py
@@ -1,6 +1,5 @@
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events.base import GraphEngineEvent
-from graphon.graph_events.graph import GraphRunPausedEvent
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
 
 
 class SuspendLayer(GraphEngineLayer):
diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py
index fef12df504..8c8daf8712 100644
--- a/api/core/app/layers/timeslice_layer.py
+++ b/api/core/app/layers/timeslice_layer.py
@@ -3,10 +3,10 @@ import uuid
 from typing import ClassVar
 
 from apscheduler.schedulers.background import BackgroundScheduler  # type: ignore
-
 from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events.base import GraphEngineEvent
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent
+
 from services.workflow.entities import WorkflowScheduleCFSPlanEntity
 from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py
index 781a0aa3d3..77c7bec67e 100644
--- a/api/core/app/layers/trigger_post_layer.py
+++ b/api/core/app/layers/trigger_post_layer.py
@@ -2,13 +2,12 @@ import logging
 from datetime import UTC, datetime
 from typing import Any, ClassVar
 
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
 from pydantic import TypeAdapter
 
 from core.db.session_factory import session_factory
 from core.workflow.system_variables import SystemVariableKey, get_system_text
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events.base import GraphEngineEvent
-from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
 from models.enums import WorkflowTriggerStatus
 from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
 from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity
diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py
index c49c4eb0ac..278d0cb30b 100644
--- a/api/core/app/llm/model_access.py
+++ b/api/core/app/llm/model_access.py
@@ -2,15 +2,16 @@ from __future__ import annotations
 
 from typing import Any
 
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.nodes.llm.entities import ModelConfig
+from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
+from graphon.nodes.llm.protocols import CredentialsProvider
+
 from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
 from core.errors.error import ProviderTokenNotInitError
 from core.model_manager import ModelInstance, ModelManager
 from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
 from core.provider_manager import ProviderManager
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.nodes.llm.entities import ModelConfig
-from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
-from graphon.nodes.llm.protocols import CredentialsProvider
 
 
 class DifyCredentialsProvider:
diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py
index 65a3f39d64..63d2235358 100644
--- a/api/core/app/llm/quota.py
+++ b/api/core/app/llm/quota.py
@@ -1,3 +1,4 @@
+from graphon.model_runtime.entities.llm_entities import LLMUsage
 from sqlalchemy import update
 from sqlalchemy.orm import Session
 
@@ -7,7 +8,6 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
 from core.errors.error import QuotaExceededError
 from core.model_manager import ModelInstance
 from extensions.ext_database import db
-from graphon.model_runtime.entities.llm_entities import LLMUsage
 from libs.datetime_utils import naive_utc_now
 from models.provider import Provider, ProviderType
 from models.provider_ids import ModelProviderID
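llm/quota.py keeps its "from sqlalchemy import update" import, and the index-tool callback at the bottom of this diff moves hit counting onto the same construct. The point of update(...).values(col=col + 1) is that the arithmetic runs server-side in a single statement, so concurrent transactions cannot lose increments the way a read-modify-write cycle can. A toy sketch of the pattern (the Counter model is hypothetical):

    from sqlalchemy import String, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Counter(Base):  # hypothetical model for illustration
        __tablename__ = "counters"
        id: Mapped[str] = mapped_column(String(36), primary_key=True)
        hits: Mapped[int] = mapped_column(default=0)

    def bump(session: Session, counter_id: str) -> None:
        # UPDATE counters SET hits = hits + 1 WHERE id = :id  (atomic in the database)
        session.execute(update(Counter).where(Counter.id == counter_id).values(hits=Counter.hits + 1))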
diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py
index 9e688589db..10b9c36d3e 100644
--- a/api/core/app/task_pipeline/based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -1,6 +1,7 @@
 import logging
 import time
 
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -17,7 +18,6 @@ from core.app.entities.task_entities import (
 )
 from core.errors.error import QuotaExceededError
 from core.moderation.output_moderation import ModerationRule, OutputModeration
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from models.enums import MessageStatus
 from models.model import Message
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index cf9cb6d051..a410fac558 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -4,6 +4,13 @@ from collections.abc import Generator
 from threading import Thread
 from typing import Any, Union, cast
 
+from graphon.file import FileTransferMethod
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    TextPromptMessageContent,
+)
+from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -53,13 +60,6 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from events.message_event import message_was_created
 from extensions.ext_database import db
-from graphon.file.enums import FileTransferMethod
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    TextPromptMessageContent,
-)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from libs.datetime_utils import naive_utc_now
 from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py
index 45f622c469..b23a33923b 100644
--- a/api/core/app/task_pipeline/message_file_utils.py
+++ b/api/core/app/task_pipeline/message_file_utils.py
@@ -1,8 +1,9 @@
 from typing import TypedDict
 
-from core.tools.signature import sign_tool_file
+from graphon.file import FileTransferMethod
 from graphon.file import helpers as file_helpers
-from graphon.file.enums import FileTransferMethod
+
+from core.tools.signature import sign_tool_file
 from models.model import MessageFile, UploadFile
 
 MAX_TOOL_FILE_EXTENSION_LENGTH = 10
diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py
index aa5291bad5..8604235ef2 100644
--- a/api/core/app/workflow/file_runtime.py
+++ b/api/core/app/workflow/file_runtime.py
@@ -9,6 +9,10 @@ import urllib.parse
 from collections.abc import Generator
 from typing import TYPE_CHECKING, Literal
 
+from graphon.file import FileTransferMethod
+from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
+from graphon.file.runtime import set_workflow_file_runtime
+
 from configs import dify_config
 from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
 from core.db.session_factory import session_factory
@@ -16,12 +20,9 @@ from core.helper.ssrf_proxy import ssrf_proxy
 from core.tools.signature import sign_tool_file
 from core.workflow.file_reference import parse_file_reference
 from extensions.ext_storage import storage
-from graphon.file.enums import FileTransferMethod
-from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
-from graphon.file.runtime import set_workflow_file_runtime
 
 if TYPE_CHECKING:
-    from graphon.file.models import File
+    from graphon.file import File
 
 
 class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py
index 5666bf1191..48cabaf4d0 100644
--- a/api/core/app/workflow/layers/llm_quota.py
+++ b/api/core/app/workflow/layers/llm_quota.py
@@ -7,18 +7,17 @@ This layer centralizes model-quota deduction outside node implementations.
 import logging
 from typing import TYPE_CHECKING, cast, final
 
+from graphon.enums import BuiltinNodeTypes
+from graphon.graph_engine.entities.commands import AbortCommand, CommandType
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
+from graphon.nodes.base.node import Node
 from typing_extensions import override
 
 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
 from core.errors.error import QuotaExceededError
 from core.model_manager import ModelInstance
-from graphon.enums import BuiltinNodeTypes
-from graphon.graph_engine.entities.commands import AbortCommand, CommandType
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase
-from graphon.graph_events.node import NodeRunSucceededEvent
-from graphon.nodes.base.node import Node
 
 if TYPE_CHECKING:
     from graphon.nodes.llm.node import LLMNode
diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py
index 837bf7ff81..8565c3076c 100644
--- a/api/core/app/workflow/layers/observability.py
+++ b/api/core/app/workflow/layers/observability.py
@@ -11,6 +11,10 @@ import logging
 from dataclasses import dataclass
 from typing import cast, final
 
+from graphon.enums import BuiltinNodeTypes, NodeType
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.graph_events import GraphNodeEventBase
+from graphon.nodes.base.node import Node
 from opentelemetry import context as context_api
 from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
 from typing_extensions import override
@@ -24,10 +28,6 @@ from extensions.otel.parser import (
     ToolNodeOTelParser,
 )
 from extensions.otel.runtime import is_instrument_flag_enabled
-from graphon.enums import BuiltinNodeTypes, NodeType
-from graphon.graph_engine.layers.base import GraphEngineLayer
-from graphon.graph_events import GraphNodeEventBase
-from graphon.nodes.base.node import Node
 
 logger = logging.getLogger(__name__)
TraceTask -from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from graphon.entities import WorkflowExecution, WorkflowNodeExecution from graphon.enums import ( WorkflowExecutionStatus, @@ -28,7 +21,7 @@ from graphon.enums import ( WorkflowNodeExecutionStatus, WorkflowType, ) -from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, @@ -45,6 +38,14 @@ from graphon.graph_events import ( NodeRunSucceededEvent, ) from graphon.node_events import NodeRunResult + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 9e3c187210..3d8a7a54f3 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,6 +6,9 @@ import re import threading from collections.abc import Iterable +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -15,8 +18,6 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8de5cb1690..6a07119244 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,7 +1,7 @@ import logging from collections.abc import Sequence -from sqlalchemy import select +from sqlalchemy import select, update from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom @@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler: ) child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: - _ = ( - db.session.query(DocumentSegment) + db.session.execute( + update(DocumentSegment) .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) + .values(hit_count=DocumentSegment.hit_count + 1) ) else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + conditions = [DocumentSegment.index_node_id == 
document.metadata["doc_id"]] if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + db.session.execute( + update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) + ) db.session.commit() diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 8a9875e4d7..143d1e696b 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,6 +3,9 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts @@ -28,11 +31,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.file import File, get_file_type_by_mime_type -from graphon.file.enums import FileTransferMethod, FileType -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 84dd653772..14d1af2e8b 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ from typing import Literal, Optional +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 089b8b8e59..04f15dee31 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,10 +2,11 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type +from graphon.file import File, FileTransferMethod, FileType + from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 9d970d5db1..72f6590e68 100644 --- 
a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias +from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field -from graphon.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index bfa4f56915..a440829b46 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,6 +6,7 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse +from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -15,7 +16,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index e99a131500..84d95c38c6 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,11 +1,10 @@ from collections.abc import Sequence from enum import StrEnum, auto -from pydantic import BaseModel, ConfigDict - from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel from graphon.model_runtime.entities.provider_entities import ProviderEntity +from pydantic import BaseModel, ConfigDict class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d90afd3f7b..8b48aa2660 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -7,6 +7,16 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -22,16 +32,6 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from 
graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index dffc7f2fc1..2c8767a32b 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import StrEnum, auto from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -12,7 +13,6 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 951e065b2c..35bfcfb6a5 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,6 +4,7 @@ from threading import Lock from typing import Any import httpx +from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -13,7 +14,6 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 17345dc203..20125ec6b3 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str): from extensions.ext_database import db from models.account import Tenant - if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): + if not (tenant := db.session.get(Tenant, tenant_id)): raise ValueError(f"Tenant with id {tenant_id} not found") assert tenant.encrypt_public_key is not None encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index dc37a36943..a1e782a094 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,13 +2,14 @@ import logging import secrets from typing import cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model 
import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index eb762c3508..60f5434bc1 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,10 +1,10 @@ from flask import Flask +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from graphon.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 46bf1d6937..3ec17bc986 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,6 +9,7 @@ from collections.abc import Mapping from typing import Any from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError @@ -34,7 +35,6 @@ from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models import Account diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 3712374305..d39630ad95 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -5,6 +5,12 @@ from collections.abc import Sequence from typing import Protocol, cast import json_repair +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from sqlalchemy import select from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload @@ -29,11 +35,6 @@ from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models import App, Message, WorkflowNodeExecutionModel from models.workflow import Workflow @@ -410,8 +411,8 @@ class LLMGenerator: model_config: ModelConfig, ideal_output: str | None, ): - last_run: Message | None = ( - db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + last_run: Message | None = db.session.scalar( + select(Message).where(Message.app_id == 
flow_id).order_by(Message.created_at.desc()).limit(1) ) if not last_run: return LLMGenerator.__instruction_modify_common( diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 81672ee7aa..a1710f11ac 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,11 +5,6 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -26,6 +21,11 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance class ResponseFormat(StrEnum): diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 92d23c6dc9..27000c947c 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -3,11 +3,12 @@ import logging from collections.abc import Mapping from typing import Any, cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 7b5a7635f1..7e35044176 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError -from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 658206128d..09c84538a9 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,14 +1,5 @@ from collections.abc import Sequence -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.file_access import DatabaseFileAccessController -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from extensions.ext_database import db -from factories import file_factory from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -19,6 
+10,15 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository diff --git a/api/core/model_manager.py b/api/core/model_manager.py index f5ff375f65..87d1d7fba6 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -2,14 +2,6 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload -from configs import dify_config -from core.entities.embedding_type import EmbeddingInputType -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import ModelLoadBalancingConfiguration -from core.errors.error import ProviderTokenNotInitError -from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager -from core.provider_manager import ProviderManager -from extensions.ext_redis import redis_client from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -23,6 +15,15 @@ from graphon.model_runtime.model_providers.__base.rerank_model import RerankMode from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from graphon.model_runtime.model_providers.__base.tts_model import TTSModel + +from configs import dify_config +from core.entities.embedding_type import EmbeddingInputType +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import ModelLoadBalancingConfiguration +from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.provider_manager import ProviderManager +from extensions.ext_redis import redis_client from models.provider import ProviderType from services.enterprise.plugin_manager_service import PluginCredentialType diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 35d4469bc1..dd038c77f1 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,7 @@ +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult -from graphon.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py 
b/api/core/ops/aliyun_trace/aliyun_trace.py index 76e81242f4..70aaf2a07b 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,6 +1,8 @@ import logging from collections.abc import Sequence +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker @@ -58,8 +60,6 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 43b204b78c..d8e105d6a3 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -2,6 +2,8 @@ import json from collections.abc import Mapping from typing import Any +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -15,8 +17,6 @@ from core.ops.aliyun_trace.entities.semconv import ( ) from core.rag.models.document import Document from extensions.ext_database import db -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants @@ -27,9 +27,7 @@ DEFAULT_FRAMEWORK_NAME = "dify" def get_user_id_from_message_data(message_data) -> str: user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id return user_id diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index e354c3909a..39d97e2882 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse +from graphon.enums import WorkflowNodeExecutionStatus from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -300,7 +301,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "app_name": node_execution.title, "status": node_execution.status, "status_message": node_execution.error or "", - "level": "ERROR" if node_execution.status == "failed" else "DEFAULT", + "level": "ERROR" if node_execution.status == WorkflowNodeExecutionStatus.FAILED else "DEFAULT", } ) @@ -361,7 +362,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) node_span.set_attributes(llm_attributes) finally: - if node_execution.status == "failed": + if node_execution.status == WorkflowNodeExecutionStatus.FAILED: set_span_status(node_span, node_execution.error) else: set_span_status(node_span) @@ -409,9 +410,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if 
trace_info.message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, trace_info.message_data.from_end_user_id) if end_user_data is not None: metadata["end_user_id"] = end_user_data.session_id diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 50a2cdea63..45b2f635ba 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): message_id: str | None = None message_data: Any | None = None - inputs: Union[str, dict[str, Any], list] | None = None - outputs: Union[str, dict[str, Any], list] | None = None + inputs: Union[str, dict[str, Any], list[Any]] | None = None + outputs: Union[str, dict[str, Any], list[Any]] | None = None start_time: datetime | None = None end_time: datetime | None = None metadata: dict[str, Any] @@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel): @field_validator("inputs", "outputs") @classmethod - def ensure_type(cls, v): + def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None: if v is None: return None if isinstance(v, str | dict | list): @@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel): model_config = ConfigDict(protected_namespaces=()) + @property + def resolved_trace_id(self) -> str | None: + """Get trace_id with intelligent fallback. + + Priority: + 1. External trace_id (from X-Trace-Id header) + 2. workflow_run_id (if this trace type has it) + 3. message_id (as final fallback) + """ + if self.trace_id: + return self.trace_id + + # Try workflow_run_id (only exists on workflow-related traces) + workflow_run_id = getattr(self, "workflow_run_id", None) + if workflow_run_id: + return workflow_run_id + + # Final fallback to message_id + return str(self.message_id) if self.message_id else None + + @property + def resolved_parent_context(self) -> tuple[str | None, str | None]: + """Resolve cross-workflow parent linking from metadata. + + Extracts typed parent IDs from the untyped ``parent_trace_context`` + metadata dict (set by tool_node when invoking nested workflows). + + Returns: + (trace_correlation_override, parent_span_id_source) where + trace_correlation_override is the outer workflow_run_id and + parent_span_id_source is the outer node_execution_id. 
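+
+        A minimal illustration of the expected shape (the IDs are hypothetical,
+        shown only to make the tuple ordering concrete)::
+
+            info.metadata["parent_trace_context"] = {
+                "parent_workflow_run_id": "outer-workflow-run-id",
+                "parent_node_execution_id": "outer-node-execution-id",
+            }
+            info.resolved_parent_context
+            # -> ("outer-workflow-run-id", "outer-node-execution-id")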
+ """ + parent_ctx = self.metadata.get("parent_trace_context") + if not isinstance(parent_ctx, dict): + return None, None + trace_override = parent_ctx.get("parent_workflow_run_id") + parent_span = parent_ctx.get("parent_node_execution_id") + return ( + trace_override if isinstance(trace_override, str) else None, + parent_span if isinstance(parent_span, str) else None, + ) + @field_serializer("start_time", "end_time") def serialize_datetime(self, dt: datetime | None) -> str | None: if dt is None: @@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_version: str error: str | None = None total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None file_list: list[str] + invoked_by: str | None = None query: str metadata: dict[str, Any] @@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo): answer_tokens: int total_tokens: int error: str | None = None - file_list: Union[str, dict[str, Any], list] | None = None + file_list: Union[str, dict[str, Any], list[Any]] | None = None message_file_data: Any | None = None conversation_mode: str gen_ai_server_time_to_first_token: float | None = None @@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] = None + file_url: Union[str, None, list[str]] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo): tenant_id: str +class PromptGenerationTraceInfo(BaseTraceInfo): + """Trace information for prompt generation operations (rule-generate, code-generate, etc.).""" + + tenant_id: str + user_id: str + app_id: str | None = None + + operation_type: str + instruction: str + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + model_provider: str + model_name: str + + latency: float + + total_price: float | None = None + currency: str | None = None + + error: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class WorkflowNodeTraceInfo(BaseTraceInfo): + workflow_id: str + workflow_run_id: str + tenant_id: str + node_execution_id: str + node_id: str + node_type: str + title: str + + status: str + error: str | None = None + elapsed_time: float + + index: int + predecessor_node_id: str | None = None + + total_tokens: int = 0 + total_price: float = 0.0 + currency: str | None = None + + model_provider: str | None = None + model_name: str | None = None + prompt_tokens: int | None = None + completion_tokens: int | None = None + + tool_name: str | None = None + + iteration_id: str | None = None + iteration_index: int | None = None + loop_id: str | None = None + loop_index: int | None = None + parallel_id: str | None = None + + node_inputs: Mapping[str, Any] | None = None + node_outputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + + invoked_by: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class DraftNodeExecutionTrace(WorkflowNodeTraceInfo): + pass + + class TaskData(BaseModel): app_id: str trace_info_type: str @@ -128,11 +246,31 @@ trace_info_info_map = { "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, "ToolTraceInfo": ToolTraceInfo, "GenerateNameTraceInfo": GenerateNameTraceInfo, + "PromptGenerationTraceInfo": PromptGenerationTraceInfo, + "WorkflowNodeTraceInfo": WorkflowNodeTraceInfo, + "DraftNodeExecutionTrace": DraftNodeExecutionTrace, } +class OperationType(StrEnum): + """Operation type for token metric labels. 
+ + Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output`` + counters so consumers can break down token usage by operation. + """ + + WORKFLOW = "workflow" + NODE_EXECUTION = "node_execution" + MESSAGE = "message" + RULE_GENERATE = "rule_generate" + CODE_GENERATE = "code_generate" + STRUCTURED_OUTPUT = "structured_output" + INSTRUCTION_MODIFY = "instruction_modify" + + class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" + DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution" WORKFLOW_TRACE = "workflow" MESSAGE_TRACE = "message" MODERATION_TRACE = "moderation" @@ -140,4 +278,6 @@ class TraceTaskName(StrEnum): DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + PROMPT_GENERATION_TRACE = "prompt_generation" + NODE_EXECUTION_TRACE = "node_execution" DATASOURCE_TRACE = "datasource" diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 4a634e2e57..3644b6b4c2 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -2,6 +2,7 @@ import logging import os from datetime import datetime, timedelta +from graphon.enums import BuiltinNodeTypes from langfuse import Langfuse from sqlalchemy.orm import sessionmaker @@ -29,7 +30,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -241,9 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id metadata["user_id"] = user_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 9f7d73b4ca..490c64af84 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -4,6 +4,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker @@ -29,7 +30,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -259,9 +259,7 @@ class LangSmithDataTrace(BaseTraceInstance): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id 
metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index 8ec69e3542..946d3cdd47 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -5,10 +5,12 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow +from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace from mlflow.tracing.provider import detach_span_from_context, set_span_in_context +from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig @@ -24,7 +26,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -320,7 +321,7 @@ class MLflowDataTrace(BaseTraceInstance): def _get_message_user_id(self, metadata: dict) -> str | None: if (end_user_id := metadata.get("from_end_user_id")) and ( - end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user_data := db.session.get(EndUser, end_user_id) ): return end_user_data.session_id @@ -447,25 +448,11 @@ class MLflowDataTrace(BaseTraceInstance): def _get_workflow_nodes(self, workflow_run_id: str): """Helper method to get workflow nodes""" - workflow_nodes = ( - db.session.query( - WorkflowNodeExecutionModel.id, - WorkflowNodeExecutionModel.tenant_id, - WorkflowNodeExecutionModel.app_id, - WorkflowNodeExecutionModel.title, - WorkflowNodeExecutionModel.node_type, - WorkflowNodeExecutionModel.status, - WorkflowNodeExecutionModel.inputs, - WorkflowNodeExecutionModel.outputs, - WorkflowNodeExecutionModel.created_at, - WorkflowNodeExecutionModel.elapsed_time, - WorkflowNodeExecutionModel.process_data, - WorkflowNodeExecutionModel.execution_metadata, - ) - .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + workflow_nodes = db.session.scalars( + select(WorkflowNodeExecutionModel) + .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .order_by(WorkflowNodeExecutionModel.created_at) - .all() - ) + ).all() return workflow_nodes def _get_node_span_type(self, node_type: str) -> str: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index a3ead548bb..2215bdeb33 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker @@ -24,7 +25,6 @@ from core.ops.entities.trace_entity import ( ) from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -288,9 +288,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["file_list"] = file_list if message_data.from_end_user_id: - 
end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 87a7579f3a..9c36d57c6f 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,22 +15,32 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum +from core.ops.entities.config_entity import ( + OPS_FILE_PATH, + TracingProviderEnum, +) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, GenerateNameTraceInfo, MessageTraceInfo, ModerationTraceInfo, + PromptGenerationTraceInfo, SuggestedQuestionTraceInfo, TaskData, ToolTraceInfo, TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) from core.ops.utils import get_message_data +from extensions.ext_database import db from extensions.ext_storage import storage -from models.engine import db +from models.account import Tenant +from models.dataset import Dataset from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig +from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks @@ -40,9 +50,144 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: + """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" + app_name = "" + workspace_name = "" + if not app_id and not tenant_id: + return app_name, workspace_name + with Session(db.engine) as session: + if app_id: + name = session.scalar(select(App.name).where(App.id == app_id)) + if name: + app_name = name + if tenant_id: + name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id)) + if name: + workspace_name = name + return app_name, workspace_name + + +_PROVIDER_TYPE_TO_MODEL: dict[str, type] = { + "builtin": BuiltinToolProvider, + "plugin": BuiltinToolProvider, + "api": ApiToolProvider, + "workflow": WorkflowToolProvider, + "mcp": MCPToolProvider, +} + + +def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str: + if not credential_id: + return "" + model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "") + if not model_cls: + return "" + with Session(db.engine) as session: + name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) # type: ignore[attr-defined] + return str(name) if name else "" + + +def _lookup_llm_credential_info( + tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm" +) -> tuple[str | None, str]: + """ + Lookup LLM credential ID and name for the given provider and model. + Returns (credential_id, credential_name). + + Handles async timing issues gracefully - if credential is deleted between lookups, + returns the ID but empty name rather than failing. 
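+
+    Shape of the result (provider/model/credential values below are illustrative
+    only, not real records)::
+
+        _lookup_llm_credential_info(tenant_id, "openai", "gpt-4o")
+        # -> ("<credential-uuid>", "My OpenAI Key")  when a custom provider is configured
+        # -> (None, "")                              when no custom provider record exists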
+ """ + if not tenant_id or not provider: + return None, "" + + try: + with Session(db.engine) as session: + # Try to find provider-level or model-level configuration + provider_record = session.scalar( + select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider, + Provider.provider_type == ProviderType.CUSTOM, + ) + ) + + if not provider_record: + return None, "" + + # Check if there's a model-specific config + credential_id = None + credential_name = "" + is_model_level = False + + if model: + # Try model-level first + model_record = session.scalar( + select(ProviderModel).where( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name == provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type, + ) + ) + + if model_record and model_record.credential_id: + credential_id = model_record.credential_id + is_model_level = True + + if not credential_id and provider_record.credential_id: + # Fall back to provider-level credential + credential_id = provider_record.credential_id + is_model_level = False + + # Lookup credential_name if we have credential_id + if credential_id: + try: + if is_model_level: + # Query ProviderModelCredential + cred_name = session.scalar( + select(ProviderModelCredential.credential_name).where( + ProviderModelCredential.id == credential_id + ) + ) + else: + # Query ProviderCredential + cred_name = session.scalar( + select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id) + ) + + if cred_name: + credential_name = str(cred_name) + except Exception as e: + # Credential might have been deleted between lookups (async timing) + # Return ID but empty name rather than failing + logger.warning( + "Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s", + credential_id, + provider, + model, + str(e), + exc_info=True, + ) + + return credential_id, credential_name + except Exception as e: + # Database query failed or other unexpected error + # Return empty rather than propagating error to telemetry emission + logger.warning( + "Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s", + tenant_id, + provider, + model, + str(e), + exc_info=True, + ) + return None, "" + + class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, key: str) -> dict[str, Any]: - match key: + def __getitem__(self, provider: str) -> dict[str, Any]: + match provider: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace @@ -149,7 +294,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): } case _: - raise KeyError(f"Unsupported tracing provider: {key}") + raise KeyError(f"Unsupported tracing provider: {provider}") provider_config_map = OpsTraceProviderConfigMap() @@ -275,10 +420,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config_data: @@ -314,7 +459,11 @@ class OpsTraceManager: if app_id is None: return None - app: App | None = db.session.query(App).where(App.id == app_id).first() + # Handle storage_id format (tenant-{uuid}) - not 
a real app_id + if isinstance(app_id, str) and app_id.startswith("tenant-"): + return None + + app = db.session.get(App, app_id) if app is None: return None @@ -388,7 +537,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App | None = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.get(App, app_id) if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -406,7 +555,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: App | None = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.get(App, app_id) if not app: raise ValueError("App not found") if not app.tracing: @@ -466,8 +615,6 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): - from repositories.factory import DifyAPIRepositoryFactory - if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: @@ -478,6 +625,77 @@ class TraceTask: cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo + @classmethod + def _calculate_workflow_token_split( + cls, session: "Session", workflow_run_id: str, tenant_id: str + ) -> tuple[int, int]: + """Sum prompt/completion tokens across all node executions for a workflow run. + + Reads from the ``outputs`` column (where LLM nodes store ``usage.prompt_tokens`` + and ``usage.completion_tokens``) rather than ``execution_metadata``, which only + carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading + large JSON blobs unnecessarily. + """ + import json + + from models.workflow import WorkflowNodeExecutionModel + + rows = ( + session.execute( + select(WorkflowNodeExecutionModel.outputs).where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + ) + .scalars() + .all() + ) + + total_prompt = 0 + total_completion = 0 + + for raw in rows: + if not raw: + continue + try: + outputs = json.loads(raw) if isinstance(raw, str) else raw + except (ValueError, TypeError): + continue + if not isinstance(outputs, dict): + continue + usage = outputs.get("usage") + if not isinstance(usage, dict): + continue + prompt = usage.get("prompt_tokens") + if isinstance(prompt, (int, float)): + total_prompt += int(prompt) + completion = usage.get("completion_tokens") + if isinstance(completion, (int, float)): + total_completion += int(completion) + + return (total_prompt, total_completion) + + @classmethod + def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str: + """Extract user ID from metadata, prioritizing end_user over account. + + Returns the actual user ID (end_user or account) who invoked the workflow, + regardless of invoke_from context. 
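+
+        Mapping, in priority order (metadata values are illustrative)::
+
+            {"from_end_user_id": "eu-1"}  -> "end_user:eu-1"
+            {"from_account_id": "acc-1"}  -> "account:acc-1"
+            {"user_id": "u-1"}            -> "user:u-1"
+            {}                            -> "anonymous"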
+ """ + # Priority 1: End user (external users via API/WebApp) + if user_id := metadata.get("from_end_user_id"): + return f"end_user:{user_id}" + + # Priority 2: Account user (internal users via console/debugger) + if user_id := metadata.get("from_account_id"): + return f"account:{user_id}" + + # Priority 3: User (internal users via console/debugger) + if user_id := metadata.get("user_id"): + return f"user:{user_id}" + + return "anonymous" + def __init__( self, trace_type: Any, @@ -491,6 +709,7 @@ class TraceTask: self.trace_type = trace_type self.message_id = message_id self.workflow_run_id = workflow_execution.id_ if workflow_execution else None + self.workflow_total_tokens: int | None = workflow_execution.total_tokens if workflow_execution else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer @@ -498,6 +717,8 @@ class TraceTask: self.app_id = None self.trace_id = None self.kwargs = kwargs + if user_id is not None and "user_id" not in self.kwargs: + self.kwargs["user_id"] = user_id external_trace_id = kwargs.get("external_trace_id") if external_trace_id: self.trace_id = external_trace_id @@ -509,9 +730,12 @@ class TraceTask: preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + workflow_run_id=self.workflow_run_id, + conversation_id=self.conversation_id, + user_id=self.user_id, + total_tokens_override=self.workflow_total_tokens, ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs), TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( message_id=self.message_id, timer=self.timer, **self.kwargs ), @@ -527,6 +751,9 @@ class TraceTask: TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), + TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs), + TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs), + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs), } return preprocess_map.get(self.trace_type, lambda: None)() @@ -541,6 +768,7 @@ class TraceTask: workflow_run_id: str | None, conversation_id: str | None, user_id: str | None, + total_tokens_override: int | None = None, ): if not workflow_run_id: return {} @@ -560,7 +788,7 @@ class TraceTask: workflow_run_version = workflow_run.version error = workflow_run.error or "" - total_tokens = workflow_run.total_tokens + total_tokens = total_tokens_override if total_tokens_override is not None else workflow_run.total_tokens file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" @@ -581,8 +809,18 @@ class TraceTask: Message.workflow_run_id == workflow_run_id, ) message_id = session.scalar(message_data_stmt) + prompt_tokens, completion_tokens = self._calculate_workflow_token_split( + session, workflow_run_id=workflow_run_id, tenant_id=tenant_id + ) - metadata = { + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id) + else: + app_name, 
workspace_name = "", "" + + metadata: dict[str, Any] = { "workflow_id": workflow_id, "conversation_id": conversation_id, "workflow_run_id": workflow_run_id, @@ -595,8 +833,14 @@ class TraceTask: "triggered_from": workflow_run.triggered_from, "user_id": user_id, "app_id": workflow_run.app_id, + "app_name": app_name, + "workspace_name": workspace_name, } + parent_trace_context = self.kwargs.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), @@ -611,6 +855,8 @@ class TraceTask: workflow_run_version=workflow_run_version, error=error, total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, file_list=file_list, query=query, metadata=metadata, @@ -618,10 +864,11 @@ class TraceTask: message_id=message_id, start_time=workflow_run.created_at, end_time=workflow_run.finished_at, + invoked_by=self._get_user_id_from_metadata(metadata), ) return workflow_trace_info - def message_trace(self, message_id: str | None): + def message_trace(self, message_id: str | None, **kwargs): if not message_id: return {} message_data = get_message_data(message_id) @@ -636,7 +883,7 @@ class TraceTask: inputs = message_data.message # get message file data - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) file_list = [] if message_file_data and message_file_data.url is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -644,6 +891,19 @@ class TraceTask: streaming_metrics = self._extract_streaming_metrics(message_data) + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + metadata = { "conversation_id": message_data.conversation_id, "ls_provider": message_data.model_provider, @@ -655,7 +915,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, "message_id": message_id, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id message_tokens = message_data.message_tokens @@ -672,7 +939,9 @@ class TraceTask: outputs=message_data.answer, file_list=file_list, start_time=created_at, - end_time=created_at + timedelta(seconds=message_data.provider_response_latency), + end_time=message_data.updated_at + if message_data.updated_at and message_data.updated_at > created_at + else created_at + timedelta(seconds=message_data.provider_response_latency), metadata=metadata, message_file_data=message_file_data, conversation_mode=conversation_mode, @@ -697,12 +966,14 @@ class TraceTask: "preset_response": moderation_result.preset_response, "query": moderation_result.query, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = 
node_execution_id # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -738,12 +1009,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -777,6 +1050,52 @@ class TraceTask: if not message_data: return {} + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + doc_list = [doc.model_dump() for doc in documents] if documents else [] + dataset_ids: set[str] = set() + for doc in doc_list: + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + if did: + dataset_ids.add(did) + + embedding_models: dict[str, dict[str, str]] = {} + if dataset_ids: + with Session(db.engine) as session: + rows = session.execute( + select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where( + Dataset.id.in_(list(dataset_ids)) + ) + ).all() + for row in rows: + embedding_models[str(row[0])] = { + "embedding_model": row[1] or "", + "embedding_model_provider": row[2] or "", + } + + # Extract rerank model info from retrieval_model kwargs + rerank_model_provider = "" + rerank_model_name = "" + if "retrieval_model" in kwargs: + retrieval_model = kwargs["retrieval_model"] + if isinstance(retrieval_model, dict): + reranking_model = retrieval_model.get("reranking_model") + if isinstance(reranking_model, dict): + rerank_model_provider = reranking_model.get("reranking_provider_name", "") + rerank_model_name = reranking_model.get("reranking_model_name", "") + metadata = { "message_id": message_id, "ls_provider": message_data.model_provider, @@ -787,13 +1106,23 @@ class TraceTask: "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, + "embedding_models": embedding_models, + "rerank_model_provider": rerank_model_provider, + "rerank_model_name": rerank_model_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( 
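+ # reuse doc_list built above (also scanned for dataset_ids) instead of re-dumping documents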
trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents] if documents else [], + documents=doc_list, start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -836,9 +1165,13 @@ class TraceTask: "error": error, "tool_parameters": tool_parameters, } + if message_data.workflow_run_id: + metadata["workflow_run_id"] = message_data.workflow_run_id + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id file_url = "" - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) if message_file_data: message_file_id = message_file_data.id if message_file_data else None type = message_file_data.type @@ -890,6 +1223,8 @@ class TraceTask: "conversation_id": conversation_id, "tenant_id": tenant_id, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id generate_name_trace_info = GenerateNameTraceInfo( trace_id=self.trace_id, @@ -904,6 +1239,182 @@ class TraceTask: return generate_name_trace_info + def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict: + tenant_id = kwargs.get("tenant_id", "") + user_id = kwargs.get("user_id", "") + app_id = kwargs.get("app_id") + operation_type = kwargs.get("operation_type", "") + instruction = kwargs.get("instruction", "") + generated_output = kwargs.get("generated_output", "") + + prompt_tokens = kwargs.get("prompt_tokens", 0) + completion_tokens = kwargs.get("completion_tokens", 0) + total_tokens = kwargs.get("total_tokens", 0) + + model_provider = kwargs.get("model_provider", "") + model_name = kwargs.get("model_name", "") + + latency = kwargs.get("latency", 0.0) + + timer = kwargs.get("timer") + start_time = timer.get("start") if timer else None + end_time = timer.get("end") if timer else None + + total_price = kwargs.get("total_price") + currency = kwargs.get("currency") + + error = kwargs.get("error") + + app_name = None + workspace_name = None + if app_id: + app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id) + + metadata = { + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id or "", + "app_name": app_name, + "workspace_name": workspace_name, + "operation_type": operation_type, + "model_provider": model_provider, + "model_name": model_name, + } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id + + return PromptGenerationTraceInfo( + trace_id=self.trace_id, + inputs=instruction, + outputs=generated_output, + start_time=start_time, + end_time=end_time, + metadata=metadata, + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=operation_type, + instruction=instruction, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model_provider=model_provider, + model_name=model_name, + latency=latency, + total_price=total_price, + currency=currency, + error=error, + ) + + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict: + node_data: dict = kwargs.get("node_execution_data", {}) + if not node_data: + return {} + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names( + 
node_data.get("app_id"), node_data.get("tenant_id") + ) + else: + app_name, workspace_name = "", "" + + # Try tool credential lookup first + credential_id = node_data.get("credential_id") + if is_enterprise_telemetry_enabled(): + credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type")) + # If no credential_id found (e.g., LLM nodes), try LLM credential lookup + if not credential_id: + llm_cred_id, llm_cred_name = _lookup_llm_credential_info( + tenant_id=node_data.get("tenant_id"), + provider=node_data.get("model_provider"), + model=node_data.get("model_name"), + model_type="llm", + ) + if llm_cred_id: + credential_id = llm_cred_id + credential_name = llm_cred_name + else: + credential_name = "" + metadata: dict[str, Any] = { + "tenant_id": node_data.get("tenant_id"), + "app_id": node_data.get("app_id"), + "app_name": app_name, + "workspace_name": workspace_name, + "user_id": node_data.get("user_id"), + "invoke_from": node_data.get("invoke_from"), + "credential_id": credential_id, + "credential_name": credential_name, + "dataset_ids": node_data.get("dataset_ids"), + "dataset_names": node_data.get("dataset_names"), + "plugin_name": node_data.get("plugin_name"), + } + + parent_trace_context = node_data.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + + message_id: str | None = None + conversation_id = node_data.get("conversation_id") + workflow_execution_id = node_data.get("workflow_execution_id") + if conversation_id and workflow_execution_id and not parent_trace_context: + with Session(db.engine) as session: + msg_id = session.scalar( + select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_execution_id, + ) + ) + if msg_id: + message_id = str(msg_id) + metadata["message_id"] = message_id + if conversation_id: + metadata["conversation_id"] = conversation_id + + return WorkflowNodeTraceInfo( + trace_id=self.trace_id, + message_id=message_id, + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata=metadata, + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + invoked_by=self._get_user_id_from_metadata(metadata), + ) + + def draft_node_execution_trace(self, **kwargs) -> 
DraftNodeExecutionTrace | dict: + node_trace = self.node_execution_trace(**kwargs) + if not isinstance(node_trace, WorkflowNodeTraceInfo): + return node_trace + return DraftNodeExecutionTrace(**node_trace.model_dump()) + def _extract_streaming_metrics(self, message_data) -> dict: if not message_data.message_metadata: return {} @@ -937,13 +1448,17 @@ class TraceQueueManager: self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) self.flask_app = current_app._get_current_object() # type: ignore + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() if trace_manager_timer is None: self.start_timer() def add_trace_task(self, trace_task: TraceTask): global trace_manager_timer, trace_manager_queue try: - if self.trace_instance: + if self._enterprise_telemetry_enabled or self.trace_instance: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception: @@ -979,20 +1494,27 @@ class TraceQueueManager: def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: - if task.app_id is None: - continue + storage_id = task.app_id + if storage_id is None: + tenant_id = task.kwargs.get("tenant_id") + if tenant_id: + storage_id = f"tenant-{tenant_id}" + else: + logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type) + continue + file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( - app_id=task.app_id, + app_id=storage_id, trace_info_type=type(trace_info).__name__, trace_info=trace_info.model_dump() if trace_info else None, ) - file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" + file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json" storage.save(file_path, task_data.model_dump_json().encode("utf-8")) file_info = { "file_id": file_id, - "app_id": task.app_id, + "app_id": storage_id, } process_trace_tasks.delay(file_info) # type: ignore diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 4f06458157..f79095d966 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -6,6 +6,8 @@ import json import logging from datetime import datetime +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import Status, StatusCode from core.ops.entities.trace_entity import ( @@ -41,11 +43,6 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 1b1b1025bc..2bd6db22bf 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -4,6 +4,10 @@ Tencent APM tracing implementation with separated concerns import logging +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -25,10 +29,6 @@ 
from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from extensions.ext_database import db -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from graphon.nodes import BuiltinNodeTypes from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index a55505822a..8d9ba4694d 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -6,6 +6,7 @@ from typing import Any, cast import wandb import weave +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -32,7 +33,6 @@ from core.ops.entities.trace_entity import ( from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -245,9 +245,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id attributes["end_user_id"] = end_user_id diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 60d08b26c9..be11d2223c 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get app """ try: - app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() + app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1)) except Exception: raise ValueError("app not found") diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 85625fc87d..c715b9171c 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -2,6 +2,20 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager @@ -18,19 +32,6 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import 
ModelInvocationUtils -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 248f8ef3e6..9478997494 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,8 +1,5 @@ -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from graphon.enums import BuiltinNodeTypes -from graphon.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, @@ -11,9 +8,8 @@ from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from graphon.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) + +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService @@ -24,7 +20,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): tenant_id: str, user_id: str, parameters: list[ParameterConfig], - model_config: ParameterExtractorModelConfig, + model_config: LLMModelConfig, instruction: str, query: str, ): @@ -74,7 +70,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): cls, tenant_id: str, user_id: str, - model_config: QuestionClassifierModelConfig, + model_config: LLMModelConfig, classes: list[ClassConfig], instruction: str, query: str, diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 1bd239a831..2177e8af90 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,10 @@ +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, Field, computed_field, model_validator from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 6aefc41400..b095b4998d 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any +from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -13,7 +14,6 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity -from 
graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 864e4b8dd7..94263ec44e 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,6 +6,8 @@ from datetime import datetime from enum import StrEnum from typing import Any, Generic, TypeVar +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin @@ -16,8 +18,6 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 704cacae2a..059f3fa9be 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -4,10 +4,6 @@ from collections.abc import Mapping from typing import Any, Literal from flask import Response -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from core.entities.provider_entities import BasicProviderConfig -from core.plugin.utils.http_parser import deserialize_response from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -18,18 +14,17 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.model_entities import ModelType -from graphon.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ) from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from graphon.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from core.entities.provider_entities import BasicProviderConfig +from core.plugin.utils.http_parser import deserialize_response class InvokeCredentials(BaseModel): @@ -176,7 +171,7 @@ class RequestInvokeParameterExtractorNode(BaseModel): """ parameters: list[ParameterConfig] - model: ParameterExtractorModelConfig + model: LLMModelConfig instruction: str query: str @@ -187,7 +182,7 @@ class RequestInvokeQuestionClassifierNode(BaseModel): """ query: str - model: QuestionClassifierModelConfig + model: LLMModelConfig classes: list[ClassConfig] instruction: str diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index f6580d3707..2d0ab3fcd7 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -5,6 +5,14 @@ from collections.abc import Callable, Generator from typing import Any, TypeVar, cast import httpx +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + 
InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from yarl import URL @@ -13,6 +21,7 @@ from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( PluginDaemonBadRequestError, + PluginDaemonClientSideError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, PluginDaemonUnauthorizedError, @@ -27,14 +36,6 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( @@ -235,7 +236,10 @@ class BasePluginClient: response.raise_for_status() except httpx.HTTPStatusError as e: logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path) - raise e + if e.response.status_code < 500: + raise PluginDaemonClientSideError(description=str(e)) + else: + raise PluginDaemonInternalServerError(description=str(e)) except Exception as e: msg = f"Failed to request plugin daemon, url: {path}" logger.exception("Failed to request plugin daemon, url: %s", path) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index c91fa71374..1e38c24717 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -2,6 +2,13 @@ import binascii from collections.abc import Generator, Sequence from typing import IO, Any +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder + from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, @@ -13,12 +20,6 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient -from graphon.model_runtime.entities.llm_entities import LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index e3fba4ef3a..22c846b6de 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -6,6 +6,13 @@ from collections.abc import Generator, Iterable, Sequence from threading import Lock from typing import IO, Any, Union +from graphon.model_runtime.entities.llm_entities import LLMResult, 
LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.runtime import ModelRuntime from pydantic import ValidationError from redis import RedisError @@ -14,13 +21,6 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.runtime import ModelRuntime from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 35abd2ae8c..4b29a6fc56 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -2,9 +2,10 @@ from __future__ import annotations from typing import TYPE_CHECKING -from core.plugin.impl.model import PluginModelClient from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model import PluginModelClient + if TYPE_CHECKING: from core.model_manager import ModelManager from core.plugin.impl.model_runtime import PluginModelRuntime diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 322f78ab4e..90350f8400 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,8 @@ from typing import Any +from graphon.file import File + from core.tools.entities.tool_entities import ToolSelector -from graphon.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index de87a09652..19b5e9223a 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,15 +1,7 @@ from collections.abc import Mapping, Sequence from typing import cast -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from graphon.file import file_manager -from 
graphon.file.models import File +from graphon.file import File, file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, @@ -21,6 +13,14 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.runtime import VariablePool +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser + class AdvancedPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 8f1d51f08a..9be70199b7 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,10 +1,5 @@ from typing import cast -from core.app.entities.app_invoke_entities import ( - ModelConfigWithCredentialsEntity, -) -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.prompt_transform import PromptTransform from graphon.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, @@ -12,6 +7,12 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.prompt_transform import PromptTransform + class AgentHistoryPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 6ff2f44cdc..4539ae9f11 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,11 +1,12 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index e091215b80..c706353ffe 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -4,12 +4,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, cast -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig 
-from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, @@ -19,10 +13,17 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: - from graphon.file.models import File + from graphon.file import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index ba76eb0c4e..dbda749925 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from typing import Any, cast -from core.prompt.simple_prompt_transform import ModelMode from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -12,6 +11,8 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, ) +from core.prompt.simple_prompt_transform import ModelMode + class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 79fd78fe80..30933239f6 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -7,6 +7,14 @@ from collections.abc import Sequence from json import JSONDecodeError from typing import TYPE_CHECKING, Any, cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -33,14 +41,6 @@ from core.helper.position_helper import is_filtered from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 2c81653559..b872ea8a8f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,3 +1,5 @@ +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from typing_extensions import TypedDict from core.model_manager import ModelInstance, ModelManager @@ -8,8 +10,6 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, 
Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 1e4aa24287..cc6ec12c75 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only from typing_extensions import TypedDict @@ -24,7 +25,6 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index a77458706a..5a8d3a2f3f 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,6 +4,7 @@ import time from abc import ABC, abstractmethod from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from configs import dify_config @@ -17,7 +18,6 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 369159767e..e5b794f80d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,13 +3,13 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index b12a0ae2d6..3bdad00712 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,6 +4,8 @@ import pickle from typing import Any, cast import numpy as np +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy.exc import IntegrityError from configs import dify_config @@ -12,8 +14,6 @@ from core.model_manager import ModelInstance from 
core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 9f36b7a225..5c10ffbf2d 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,6 +8,17 @@ from typing import Any, cast logger = logging.getLogger(__name__) +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType + from core.app.file_access import DatabaseFileAccessController from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail @@ -31,16 +42,6 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping -from graphon.file import File, FileTransferMethod, FileType, file_manager -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4ebf095904..087736d0b0 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any -from pydantic import BaseModel, Field - from graphon.file import File +from pydantic import BaseModel, Field class ChildDocument(BaseModel): diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 6c6b077cc2..211a9f5c5c 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,5 +1,8 @@ import base64 +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult + from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType @@ -7,8 +10,6 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult from models.model import UploadFile diff 
--git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d0732b269a..49123e13d0 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -2,6 +2,7 @@ import math from collections import Counter import numpy as np +from graphon.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelManager from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -11,7 +12,6 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner -from graphon.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 49b91707ec..1abea6639e 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,11 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select from sqlalchemy.orm import Session @@ -66,11 +71,6 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( ) from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.file import File, FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index e617a9660e..dce7b6226c 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,9 +1,10 @@ from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py 
b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 83e58fe0f9..dd280cdf6a 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,6 +1,10 @@ from collections.abc import Generator, Sequence from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota from core.model_manager import ModelInstance, ModelManager @@ -8,9 +12,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelType PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 2c27ac3cf6..e6aec4a3af 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -6,6 +6,8 @@ import codecs import re from typing import Any +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer + from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import ( TS, @@ -15,7 +17,6 @@ from core.rag.splitter.text_splitter import ( Set, Union, ) -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index d0164b76dc..465f43da73 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -8,11 +8,11 @@ providing improved performance by offloading database operations to background w import logging from typing import Union +from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 52361cf6dc..22ef44b3dc 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -9,6 +9,7 @@ import logging from collections.abc import Sequence from typing import Union +from graphon.entities import 
WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,7 +17,6 @@ from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dafdbf641a..ed6d44f434 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -9,11 +9,11 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Protocol, Union +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 02625e242f..72d9394149 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Protocol +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import select from sqlalchemy.orm import Session, selectinload @@ -17,8 +19,6 @@ from core.workflow.human_input_compat import ( InteractiveSurfaceDeliveryMethod, is_human_input_webapp_enabled, ) -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 1ee5d4ae77..85d20b675d 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,13 +6,13 @@ import json import logging from typing import Union +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 749ab44a14..a72bfa378b 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,6 
+10,10 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Any, TypeVar, Union
 
 import psycopg2.errors
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from graphon.model_runtime.utils.encoders import jsonable_encoder
+from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from sqlalchemy import UnaryExpression, asc, desc, select
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import IntegrityError
@@ -19,10 +23,6 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att
 from configs import dify_config
 from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository
 from extensions.ext_storage import storage
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.helper import extract_tenant_id
 from libs.uuid_utils import uuidv7
 from models import (
diff --git a/api/core/telemetry/__init__.py b/api/core/telemetry/__init__.py
new file mode 100644
index 0000000000..ae4f53f3b7
--- /dev/null
+++ b/api/core/telemetry/__init__.py
@@ -0,0 +1,43 @@
+"""Telemetry facade.
+
+Thin public API for emitting telemetry events. All routing logic
+lives in ``core.telemetry.gateway`` which is shared by both CE and EE.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from core.ops.entities.trace_entity import TraceTaskName
+from core.telemetry.events import TelemetryContext, TelemetryEvent
+from core.telemetry.gateway import emit as gateway_emit
+from core.telemetry.gateway import get_trace_task_to_case
+
+if TYPE_CHECKING:
+    from core.ops.ops_trace_manager import TraceQueueManager
+
+
+def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
+    """Emit a telemetry event.
+
+    Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a
+    ``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``.
+    """
+    case = get_trace_task_to_case().get(event.name)
+    if case is None:
+        return
+
+    context: dict[str, object] = {
+        "tenant_id": event.context.tenant_id,
+        "user_id": event.context.user_id,
+        "app_id": event.context.app_id,
+    }
+    gateway_emit(case, context, event.payload, trace_manager)
+
+
+__all__ = [
+    "TelemetryContext",
+    "TelemetryEvent",
+    "TraceTaskName",
+    "emit",
+]
diff --git a/api/core/telemetry/events.py b/api/core/telemetry/events.py
new file mode 100644
index 0000000000..35ace47510
--- /dev/null
+++ b/api/core/telemetry/events.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from core.ops.entities.trace_entity import TraceTaskName
+
+
+@dataclass(frozen=True)
+class TelemetryContext:
+    tenant_id: str | None = None
+    user_id: str | None = None
+    app_id: str | None = None
+
+
+@dataclass(frozen=True)
+class TelemetryEvent:
+    name: TraceTaskName
+    context: TelemetryContext
+    payload: dict[str, Any]
diff --git a/api/core/telemetry/gateway.py b/api/core/telemetry/gateway.py
new file mode 100644
index 0000000000..7b013d0563
--- /dev/null
+++ b/api/core/telemetry/gateway.py
@@ -0,0 +1,239 @@
+"""Telemetry gateway — single routing layer for all editions.
+
+Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either
+the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only
+metric/log Celery queue.
+
+This module lives in ``core/`` so both CE and EE share one routing table
+and one ``emit()`` entry point. No separate enterprise gateway module is
+needed — enterprise-specific dispatch (Celery task, payload offloading)
+is handled here behind lazy imports that no-op in CE.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import uuid
+from typing import TYPE_CHECKING, Any
+
+from core.ops.entities.trace_entity import TraceTaskName
+from enterprise.telemetry.contracts import CaseRoute, SignalType
+from extensions.ext_storage import storage
+
+if TYPE_CHECKING:
+    from core.ops.ops_trace_manager import TraceQueueManager
+    from enterprise.telemetry.contracts import TelemetryCase
+
+logger = logging.getLogger(__name__)
+
+PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
+
+# ---------------------------------------------------------------------------
+# Routing table — authoritative mapping for all editions
+# ---------------------------------------------------------------------------
+
+_case_to_trace_task: dict[TelemetryCase, TraceTaskName] | None = None
+_case_routing: dict[TelemetryCase, CaseRoute] | None = None
+
+
+def _get_case_to_trace_task() -> dict[TelemetryCase, TraceTaskName]:
+    global _case_to_trace_task
+    if _case_to_trace_task is None:
+        from enterprise.telemetry.contracts import TelemetryCase
+
+        _case_to_trace_task = {
+            TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
+            TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
+            TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
+            TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
+            TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
+            TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
+            TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
+            TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
+            TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
+            TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
+        }
+    return _case_to_trace_task
+
+
+def get_trace_task_to_case() -> dict[TraceTaskName, TelemetryCase]:
+    """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task)."""
+    return {v: k for k, v in _get_case_to_trace_task().items()}
+
+
+def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
+    global _case_routing
+    if _case_routing is None:
+        from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase
+
+        _case_routing = {
+            # TRACE — CE-eligible (flow in both CE and EE)
+            TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            # TRACE — enterprise-only
+            TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
+            TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
+            TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
+            # METRIC_LOG — enterprise-only (signal-driven, not trace)
+            TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
+            TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
+            TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
+            TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
+        }
+    return _case_routing
+
+
+def __getattr__(name: str) -> dict:
+    """Lazy module-level access to routing tables."""
+    if name == "CASE_ROUTING":
+        return _get_case_routing()
+    if name == "CASE_TO_TRACE_TASK":
+        return _get_case_to_trace_task()
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def is_enterprise_telemetry_enabled() -> bool:
+    try:
+        from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
+
+        return is_enterprise_telemetry_enabled()
+    except Exception:
+        return False
+
+
+def _handle_payload_sizing(
+    payload: dict[str, Any],
+    tenant_id: str,
+    event_id: str,
+) -> tuple[dict[str, Any], str | None]:
+    """Inline or offload payload based on size.
+
+    Returns ``(payload_for_envelope, storage_key | None)``. Payloads
+    exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object
+    storage and replaced with an empty dict in the envelope.
+    """
+    try:
+        payload_json = json.dumps(payload)
+        payload_size = len(payload_json.encode("utf-8"))
+    except (TypeError, ValueError):
+        logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
+        return payload, None
+
+    if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES:
+        return payload, None
+
+    storage_key = f"telemetry/{tenant_id}/{event_id}.json"
+    try:
+        storage.save(storage_key, payload_json.encode("utf-8"))
+        logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size)
+        return {}, storage_key
+    except Exception:
+        logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
+        return payload, None
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+def emit(
+    case: TelemetryCase,
+    context: dict[str, Any],
+    payload: dict[str, Any],
+    trace_manager: TraceQueueManager | None = None,
+) -> None:
+    """Route a telemetry event to the correct pipeline.
+
+    TRACE events are enqueued into ``TraceQueueManager`` (works in both CE
+    and EE). Enterprise-only traces are silently dropped when EE is
+    disabled.
+
+    METRIC_LOG events are dispatched to the enterprise Celery queue;
+    silently dropped when enterprise telemetry is unavailable.
+    """
+    route = _get_case_routing().get(case)
+    if route is None:
+        logger.warning("Unknown telemetry case: %s, dropping event", case)
+        return
+
+    if not route.ce_eligible and not is_enterprise_telemetry_enabled():
+        logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
+        return
+
+    if route.signal_type == SignalType.TRACE:
+        _emit_trace(case, context, payload, trace_manager)
+    else:
+        _emit_metric_log(case, context, payload)
+
+
+def _emit_trace(
+    case: TelemetryCase,
+    context: dict[str, Any],
+    payload: dict[str, Any],
+    trace_manager: TraceQueueManager | None,
+) -> None:
+    from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
+    from core.ops.ops_trace_manager import TraceTask
+
+    trace_task_name = _get_case_to_trace_task().get(case)
+    if trace_task_name is None:
+        logger.warning("No TraceTaskName mapping for case: %s", case)
+        return
+
+    queue_manager = trace_manager or LocalTraceQueueManager(
+        app_id=context.get("app_id"),
+        user_id=context.get("user_id"),
+    )
+    queue_manager.add_trace_task(TraceTask(trace_task_name, user_id=context.get("user_id"), **payload))
+    logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id"))
+
+
+def _emit_metric_log(
+    case: TelemetryCase,
+    context: dict[str, Any],
+    payload: dict[str, Any],
+) -> None:
+    """Build envelope and dispatch to enterprise Celery queue.
+
+    No-ops when the enterprise telemetry task is not importable (CE mode).
+    """
+    try:
+        from tasks.enterprise_telemetry_task import process_enterprise_telemetry
+    except ImportError:
+        logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case)
+        return
+
+    tenant_id = context.get("tenant_id") or ""
+    event_id = str(uuid.uuid4())
+
+    payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id)
+
+    from enterprise.telemetry.contracts import TelemetryEnvelope
+
+    envelope = TelemetryEnvelope(
+        case=case,
+        tenant_id=tenant_id,
+        event_id=event_id,
+        payload=payload_for_envelope,
+        metadata={"payload_ref": payload_ref} if payload_ref else None,
+    )
+
+    process_enterprise_telemetry.delay(envelope.model_dump_json())
+    logger.debug(
+        "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
+        case,
+        tenant_id,
+        event_id,
+    )
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
index 40bf2e98c2..e539074303 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
@@ -2,14 +2,15 @@ import io
 
 from collections.abc import Generator
 from typing import Any
 
+from graphon.file import FileType
+from graphon.file.file_manager import download
+from graphon.model_runtime.entities.model_entities import ModelType
+
 from core.model_manager import ModelManager
 from core.plugin.entities.parameters import PluginParameterOption
 from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-from graphon.file.enums import FileType
-from graphon.file.file_manager import download
-from graphon.model_runtime.entities.model_entities import ModelType
 from services.model_provider_service import ModelProviderService
 
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
index ac3820f1ab..f49c669fe0 100644
---
a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -2,12 +2,13 @@ import io from collections.abc import Generator from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index d41503e1e6..14af63a962 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,11 +1,12 @@ from __future__ import annotations +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage + from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 168e5f4493..0a2c37c563 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -6,6 +6,7 @@ from typing import Any, Union from urllib.parse import urlencode import httpx +from graphon.file.file_manager import download from core.helper import ssrf_proxy from core.tools.__base.tool import Tool @@ -13,7 +14,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from graphon.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 08640befb4..d5d3d1b1d9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,6 +2,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any, Literal +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -9,7 +10,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from graphon.model_runtime.utils.encoders import jsonable_encoder 
class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 00fc8a8282..f6d09472b3 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -6,6 +6,8 @@ import logging from collections.abc import Generator, Mapping from typing import Any, cast +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata + from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import ( @@ -21,7 +23,6 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 1fd259f3bb..685d687d8c 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -7,6 +7,7 @@ from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Union, cast +from graphon.file import FileTransferMethod, FileType from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom @@ -32,8 +33,6 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.file import FileType -from graphon.file.models import FileTransferMethod from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 2ec292602c..7ac29cf069 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -10,13 +10,13 @@ from typing import Union from uuid import uuid4 import httpx +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 250dd91bfd..58190d1089 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,4 +1,4 @@ -from sqlalchemy import select +from sqlalchemy import delete, select from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -31,7 +31,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete() + db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id)) # insert new labels for label in labels: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4870adb7b5..a58d310313 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -8,6 +8,7 @@ from threading import Lock from typing 
import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa +from graphon.runtime import VariablePool from sqlalchemy import select from sqlalchemy.orm import Session from yarl import URL @@ -25,7 +26,6 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from extensions.ext_database import db -from graphon.runtime.variable_pool import VariablePool from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService @@ -33,11 +33,12 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: pass +from graphon.model_runtime.utils.encoders import jsonable_encoder + from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered -from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import Tool from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -57,7 +58,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -255,11 +255,11 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"no default provider for {provider_id}") else: - builtin_provider = ( - db.session.query(BuiltinToolProvider) + builtin_provider = db.session.scalar( + select(BuiltinToolProvider) .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() + .limit(1) ) if builtin_provider is None: @@ -326,7 +326,7 @@ class ToolManager: tenant_id=tenant_id, user_id=user_id, credentials=dict(decrypted_credentials), - credential_type=CredentialType.of(builtin_provider.credential_type), + credential_type=builtin_provider.credential_type, runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -818,13 +818,13 @@ class ToolManager: :return: the provider controller, the credentials """ - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.id == provider_id, ApiToolProvider.tenant_id == tenant_id, ) - .first() + .limit(1) ) if provider is None: @@ -872,13 +872,13 @@ class ToolManager: get api provider """ provider_name = provider - provider_obj: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider_obj: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider, ) - .first() + .limit(1) ) if provider_obj is 
None: @@ -964,10 +964,10 @@ class ToolManager: @classmethod def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - workflow_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) + workflow_provider: WorkflowToolProvider | None = db.session.scalar( + select(WorkflowToolProvider) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() + .limit(1) ) if workflow_provider is None: @@ -981,10 +981,10 @@ class ToolManager: @classmethod def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - api_provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + api_provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) - .first() + .limit(1) ) if api_provider is None: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index dad5133a7a..e63435db98 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,6 +1,7 @@ import threading from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select @@ -15,7 +16,6 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { @@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): context_list: list[RetrievalSourceMetadata] = [] resource_number = 1 for segment in sorted_segments: - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) document_stmt = select(Document).where( Document.id == segment.document_id, Document.enabled == True, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f3d390ed59..cbd8bdb36c 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if self.return_resource: for record in records: segment = record.segment - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) dataset_document_stmt = select(DatasetDocument).where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 5cf46b2564..bb5b3ba76e 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -8,11 +8,11 @@ from uuid import UUID import numpy as np import pytz +from graphon.file import File, FileTransferMethod, FileType from 
core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e1d41cb39..8d6f83dc07 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,9 +8,6 @@ import json from decimal import Decimal from typing import cast -from core.model_manager import ModelManager -from core.tools.entities.tool_entities import ToolProviderType -from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -23,6 +20,10 @@ from graphon.model_runtime.errors.invoke import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType +from extensions.ext_database import db from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 1e4f3ed2a7..c4b7d57449 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,12 +1,13 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.entities import OutputVariableEntity from graphon.variables.input_entities import VariableEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError + class WorkflowToolConfigurationUtils: @classmethod diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 716368c191..f48b24be30 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping +from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field from sqlalchemy.orm import Session @@ -23,7 +24,6 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 495fcd48b3..a3fb4eda92 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,6 +5,8 @@ import logging from collections.abc import Generator, Mapping, Sequence from 
typing import Any, cast +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -20,8 +22,6 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolInvokeError from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from models import Account, Tenant from models.model import App, EndUser from models.utils.file_input_compat import build_file_from_stored_mapping diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 24c1271488..61d1cd8540 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -8,6 +8,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse @@ -27,7 +28,6 @@ from core.trigger.debug.events import ( from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig from extensions.ext_redis import redis_client -from graphon.entities.graph_config import NodeConfigDict from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at from models.model import App diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py index 75a0a0c202..c95516a240 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_compat.py @@ -14,13 +14,12 @@ from typing import Annotated, Any, ClassVar, Literal import bleach import markdown -from markdown.extensions.tables import TableExtension -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter - from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.variable_template_parser import VariableTemplateParser from graphon.runtime import VariablePool from graphon.variables.consts import SELECTORS_LENGTH +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter class DeliveryMethodType(enum.StrEnum): diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 028e38fbee..8cc21d2cd9 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -4,6 +4,22 @@ from collections.abc import Callable, Iterator, Mapping, MutableMapping from functools import lru_cache from typing import TYPE_CHECKING, Any, TypeAlias, cast, final +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from 
graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from sqlalchemy import select from sqlalchemy.orm import Session from typing_extensions import override @@ -40,22 +56,6 @@ from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db -from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.file.file_manager import file_manager -from graphon.graph.graph import NodeFactory -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.nodes.base.node import Node -from graphon.nodes.code.code_node import WorkflowCodeExecutor -from graphon.nodes.code.entities import CodeLanguage -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.nodes.document_extractor import UnstructuredApiConfig -from graphon.nodes.http_request import build_http_request_config -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from models.model import Conversation if TYPE_CHECKING: diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 2e632e56f0..19cb3a7b0a 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -4,32 +4,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.file_access import DatabaseFileAccessController -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_manager import ModelInstance -from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError -from core.plugin.impl.plugin import PluginInstaller -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormRepository, - HumanInputFormRepositoryImpl, -) -from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.tool_file_manager import ToolFileManager -from core.tools.tool_manager import ToolManager -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from 
core.workflow.file_reference import build_file_reference -from extensions.ext_database import db -from factories import file_factory from graphon.file import FileTransferMethod, FileType from graphon.model_runtime.entities import LLMMode from graphon.model_runtime.entities.llm_entities import ( @@ -60,6 +34,32 @@ from graphon.nodes.tool_runtime_entities import ( ToolRuntimeMessage, ToolRuntimeParameter, ) +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.file_access import DatabaseFileAccessController +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories import file_factory from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService @@ -76,12 +76,13 @@ from .human_input_compat import ( from .system_variables import SystemVariableKey, get_system_text if TYPE_CHECKING: - from core.tools.__base.tool import Tool - from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage from graphon.file import File from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.tool.entities import ToolNodeData + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + _file_access_controller = DatabaseFileAccessController() diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7b000101b0..bfd5536e4a 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,14 +3,15 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.workflow.system_variables import SystemVariableKey, get_system_text from graphon.entities.graph_config import NodeConfigDict from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from 
core.workflow.system_variables import SystemVariableKey, get_system_text + from .entities import AgentNodeData from .exceptions import ( AgentInvocationError, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 51452c29a3..c52aad150b 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index f44681377d..db74590ed7 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -3,14 +3,6 @@ from __future__ import annotations from collections.abc import Generator, Mapping from typing import Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from extensions.ext_database import db -from factories import file_factory from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata @@ -23,6 +15,14 @@ from graphon.node_events import ( StreamCompletedEvent, ) from graphon.variables.segments import ArrayFileSegment +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import DatabaseFileAccessController +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from extensions.ext_database import db +from factories import file_factory from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index a872774c98..be50edbc4d 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -4,6 +4,8 @@ import json from collections.abc import Sequence from typing import Any, cast +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from packaging.version import Version from pydantic import ValidationError from sqlalchemy import select @@ -19,8 +21,6 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolP from core.tools.tool_manager import ToolManager from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.runtime import VariablePool from models.model import Conversation from 
.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 38f39b3f94..d9247b2593 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,18 +1,23 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( + BuiltinNodeTypes, + NodeExecutionType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError from core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import SystemVariableKey, get_system_segment -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey -from graphon.node_events import NodeRunResult, StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 28966f2392..cad32f8d5b 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,10 +1,9 @@ from typing import Any, Literal, Union -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 11339bb122..cba6c12dca 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,12 +1,12 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index b465a2d8ff..bb72fe3881 100644 --- 
a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,17 +2,17 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template + from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 3f7cc364d3..b1fa8593ef 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,11 +1,10 @@ from collections.abc import Sequence from typing import Literal -from pydantic import BaseModel, Field - from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.llm.entities import ModelConfig, VisionConfig +from pydantic import BaseModel, Field class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 117f426ade..13624b27b3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,11 +8,6 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( @@ -32,6 +27,12 @@ from graphon.variables import ( ) from graphon.variables.segments import ArrayObjectSegment +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference + from .entities import ( Condition, KnowledgeRetrievalNodeData, @@ -44,7 +45,7 @@ from .exc import ( from .retrieval 
import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: - from graphon.file.models import File + from graphon.file import File from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index ea45dcf5c2..39e2008a2c 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Protocol +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from graphon.model_runtime.entities import LLMUsage -from graphon.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index 23ed2cd408..bf5be2379a 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index a2c952a899..e50de11bb9 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,13 +1,13 @@ from collections.abc import Mapping from typing import Any -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID + from .entities import TriggerEventNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 207c1e7253..f14ca893c9 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py 
b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index dd80617dfc..a9753ab387 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,11 @@ from collections.abc import Mapping +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node + from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index 3125fe17e6..4d5ad72154 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum -from pydantic import BaseModel, Field, field_validator - -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from graphon.entities.base_node_data import BaseNodeData from graphon.enums import NodeType from graphon.variables.types import SegmentType +from pydantic import BaseModel, Field, field_validator + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE _WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( { diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 6858d6dc35..ebaac93934 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,12 +2,7 @@ import logging from collections.abc import Mapping from typing import Any -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment_with_type -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.enums import NodeExecutionType +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.file import FileTransferMethod from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node @@ -15,6 +10,11 @@ from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.variables.types import SegmentType from graphon.variables.variables import FileVariable +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment_with_type + from .entities import ContentType, WebhookData logger = logging.getLogger(__name__) diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py index b4ffb37549..d51cfadd09 100644 --- a/api/core/workflow/template_rendering.py +++ b/api/core/workflow/template_rendering.py @@ -3,10 +3,11 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from 
graphon.nodes.code.entities import CodeLanguage from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor + class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 7429c95c7c..2346a95d6a 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -3,6 +3,20 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import Any +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool + from configs import dify_config from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError @@ -21,20 +35,6 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file.models import File -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from graphon.graph_engine.protocols.command_channel import CommandChannel -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow logger = logging.getLogger(__name__) diff --git a/api/graphon/__init__.py b/api/enterprise/__init__.py similarity index 100% rename from api/graphon/__init__.py rename to api/enterprise/__init__.py diff --git a/api/enterprise/telemetry/DATA_DICTIONARY.md b/api/enterprise/telemetry/DATA_DICTIONARY.md new file mode 100644 index 0000000000..60d482cd1c --- /dev/null +++ b/api/enterprise/telemetry/DATA_DICTIONARY.md @@ -0,0 +1,525 @@ +# Dify Enterprise Telemetry Data Dictionary + +Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md). + +## Resource Attributes + +Attached to every signal (Span, Metric, Log). 
+ +| Attribute | Type | Example | +|-----------|------|---------| +| `service.name` | string | `dify` | +| `host.name` | string | `dify-api-7f8b` | + +## Traces (Spans) + +### `dify.workflow.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID (Workflow Run ID) | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Unique ID for this run | +| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. | +| `dify.workflow.error` | string | Error message if failed | +| `dify.workflow.elapsed_time` | float | Total execution time (seconds) | +| `dify.invoke_from` | string | `api`, `webapp`, `debug` | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.message.id` | string | Message ID (optional) | +| `dify.invoked_by` | string | User ID who triggered the run | +| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) | +| `gen_ai.user.id` | string | End-user identifier (optional) | +| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) | +| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) | +| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) | +| `dify.parent.app.id` | string | Parent app ID (optional) | + +### `dify.node.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Workflow Run ID | +| `dify.message.id` | string | Message ID (optional) | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.node.execution_id` | string | Unique node execution ID | +| `dify.node.id` | string | Node ID in workflow graph | +| `dify.node.type` | string | Node type (see appendix) | +| `dify.node.title` | string | Display title | +| `dify.node.status` | string | `succeeded`, `failed` | +| `dify.node.error` | string | Error message if failed | +| `dify.node.elapsed_time` | float | Execution time (seconds) | +| `dify.node.index` | int | Execution order index | +| `dify.node.predecessor_node_id` | string | Triggering node ID | +| `dify.node.iteration_id` | string | Iteration ID (optional) | +| `dify.node.loop_id` | string | Loop ID (optional) | +| `dify.node.parallel_id` | string | Parallel branch ID (optional) | +| `dify.node.invoked_by` | string | User ID who triggered execution | +| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) | +| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) | +| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) | +| `gen_ai.request.model` | string | LLM model name (LLM nodes only) | +| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) | +| `gen_ai.user.id` | string | End-user identifier (optional) | + +### `dify.node.execution.draft` + +Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs. + +## Counters + +All counters are cumulative and emitted at 100% accuracy. 
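+Because these counters are cumulative (monotonic), dashboards should query windowed increases rather than raw values — e.g. `sum(increase(dify_tokens_total{operation_type!="node_execution"}[24h]))` for tokens consumed over the last day, assuming the default OTLP→Prometheus name mapping used in the query examples below.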
+ +### Token Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.tokens.total` | `{token}` | Total tokens consumed | +| `dify.tokens.input` | `{token}` | Input (prompt) tokens | +| `dify.tokens.output` | `{token}` | Output (completion) tokens | + +**Labels:** + +- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution) + +⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting. + +#### Token Hierarchy & Query Patterns + +Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting: + +``` +App-level total +├── workflow ← sum of all node_execution tokens (DO NOT add both) +│ └── node_execution ← per-node breakdown +├── message ← independent (non-workflow chat apps only) +├── rule_generate ← independent helper LLM call +├── code_generate ← independent helper LLM call +├── structured_output ← independent helper LLM call +└── instruction_modify ← independent helper LLM call +``` + +**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both. + +**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`. +App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries. + +**Common queries** (PromQL): + +```promql +# ── Totals ────────────────────────────────────────────────── +# App-level total (exclude node_execution to avoid double-counting) +sum by (app_id) (dify_tokens_total{operation_type!="node_execution"}) + +# Single app total +sum (dify_tokens_total{app_id="<app_id>", operation_type!="node_execution"}) + +# Per-tenant totals +sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"}) + +# ── Drill-down ────────────────────────────────────────────── +# Workflow-level tokens for an app +sum (dify_tokens_total{app_id="<app_id>", operation_type="workflow"}) + +# Node-level breakdown within an app +sum by (node_type) (dify_tokens_total{app_id="<app_id>", operation_type="node_execution"}) + +# Model breakdown for an app +sum by (model_provider, model_name) (dify_tokens_total{app_id="<app_id>"}) + +# Input vs output per model +sum by (model_name) (dify_tokens_input_total{app_id="<app_id>"}) +sum by (model_name) (dify_tokens_output_total{app_id="<app_id>"}) + +# ── Rates ─────────────────────────────────────────────────── +# Token consumption rate (per hour) +sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h])) + +# Per-app consumption rate +sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h])) +``` + +**Finding `app_id` from app name** (trace query — Tempo / Jaeger): + +``` +{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id) +``` + +### Request Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.requests.total` | `{request}` | Total operations count | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `moderation` | `tenant_id`, `app_id` | +|
`suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dataset_retrieval` | `tenant_id`, `app_id` | +| `generate_name` | `tenant_id`, `app_id` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` | + +### Error Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.errors.total` | `{error}` | Total failed operations | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +### Other Counters + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` | +| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` | +| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` | +| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` | +| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` | + +## Histograms + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` | +| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` | +| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` | +| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +## Structured Logs + +### Span Companion Logs + +Logs that accompany spans. 
Signal type: `span_detail` + +#### `dify.workflow.run` Companion Log + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.workflow.version` | string | Yes | Workflow definition version | +| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) | +| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) | +| `dify.workflow.query` | string | No | User query text (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.workflow.run"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.invoke_from` | string | No | Invocation source | +| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) | +| `dify.node.total_price` | float | No | Cost (LLM nodes only) | +| `dify.node.currency` | string | No | Currency code (LLM nodes only) | +| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) | +| `dify.node.loop_index` | int | No | Loop index (loop nodes) | +| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) | +| `dify.credential.name` | string | No | Credential name (plugin nodes) | +| `dify.credential.id` | string | No | Credential ID (plugin nodes) | +| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) | +| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) | +| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) | +| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) | +| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +### Standalone Logs + +Logs without structural spans. 
Signal type: `metric_only` + +#### `dify.message.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.message.run"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID (32-char hex) | +| `span_id` | string | OTEL span ID (16-char hex) | +| `tenant_id` | string | Tenant identifier | +| `user_id` | string | User identifier (optional) | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.workflow.run_id` | string | Workflow run ID (optional) | +| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.message.status` | string | `succeeded`, `failed` | +| `dify.message.error` | string | Error message (if failed) | +| `dify.message.duration` | float | Duration (seconds) | +| `dify.message.time_to_first_token` | float | TTFT (seconds) | +| `dify.message.inputs` | string/JSON | Inputs (content-gated) | +| `dify.message.outputs` | string/JSON | Outputs (content-gated) | + +#### `dify.tool.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.tool.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.tool.name` | string | Tool name | +| `dify.tool.duration` | float | Duration (seconds) | +| `dify.tool.status` | string | `succeeded`, `failed` | +| `dify.tool.error` | string | Error message (if failed) | +| `dify.tool.inputs` | string/JSON | Inputs (content-gated) | +| `dify.tool.outputs` | string/JSON | Outputs (content-gated) | +| `dify.tool.parameters` | string/JSON | Parameters (content-gated) | +| `dify.tool.config` | string/JSON | Configuration (content-gated) | + +#### `dify.moderation.check` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.moderation.check"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.moderation.type` | string | `input`, `output` | +| `dify.moderation.action` | string | `pass`, `block`, `flag` | +| `dify.moderation.flagged` | boolean | Whether flagged | +| `dify.moderation.categories` | JSON array | Flagged categories | +| `dify.moderation.query` | string | Content (content-gated) | + +#### `dify.suggested_question.generation` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.suggested_question.generation"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string 
| Message identifier | +| `dify.suggested_question.count` | int | Number of questions | +| `dify.suggested_question.duration` | float | Duration (seconds) | +| `dify.suggested_question.status` | string | `succeeded`, `failed` | +| `dify.suggested_question.error` | string | Error message (if failed) | +| `dify.suggested_question.questions` | JSON array | Questions (content-gated) | + +#### `dify.dataset.retrieval` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.dataset.retrieval"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.dataset.id` | string | Dataset identifier | +| `dify.dataset.name` | string | Dataset name | +| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) | +| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) | +| `dify.retrieval.rerank_provider` | string | Rerank model provider | +| `dify.retrieval.rerank_model` | string | Rerank model name | +| `dify.retrieval.query` | string | Search query (content-gated) | +| `dify.retrieval.document_count` | int | Documents retrieved | +| `dify.retrieval.duration` | float | Duration (seconds) | +| `dify.retrieval.status` | string | `succeeded`, `failed` | +| `dify.retrieval.error` | string | Error message (if failed) | +| `dify.dataset.documents` | JSON array | Documents (content-gated) | + +#### `dify.generate_name.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.generate_name.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.conversation.id` | string | Conversation identifier | +| `dify.generate_name.duration` | float | Duration (seconds) | +| `dify.generate_name.status` | string | `succeeded`, `failed` | +| `dify.generate_name.error` | string | Error message (if failed) | +| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) | +| `dify.generate_name.outputs` | string | Generated name (content-gated) | + +#### `dify.prompt_generation.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.prompt_generation.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.prompt_generation.duration` | float | Duration (seconds) | +| `dify.prompt_generation.status` | string | `succeeded`, `failed` | +| `dify.prompt_generation.error` | string | Error message (if failed) | +| `dify.prompt_generation.instruction` | string | Instruction (content-gated) | +| 
`dify.prompt_generation.output` | string/JSON | Output (content-gated) | + +#### `dify.app.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` | +| `dify.app.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.updated` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.updated"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.updated_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.deleted` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.deleted"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.deleted_at` | string | Timestamp (ISO 8601) | + +#### `dify.feedback.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.feedback.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.feedback.rating` | string | `like`, `dislike`, `null` | +| `dify.feedback.content` | string | Feedback text (content-gated) | +| `dify.feedback.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.telemetry.rehydration_failed` + +Diagnostic event for telemetry system health monitoring. + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.telemetry.error` | string | Error message | +| `dify.telemetry.payload_type` | string | Payload type (see appendix) | +| `dify.telemetry.correlation_id` | string | Correlation ID | + +## Content-Gated Attributes + +When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`). 
+ +| Attribute | Signal | +|-----------|--------| +| `dify.workflow.inputs` | `dify.workflow.run` | +| `dify.workflow.outputs` | `dify.workflow.run` | +| `dify.workflow.query` | `dify.workflow.run` | +| `dify.node.inputs` | `dify.node.execution` | +| `dify.node.outputs` | `dify.node.execution` | +| `dify.node.process_data` | `dify.node.execution` | +| `dify.message.inputs` | `dify.message.run` | +| `dify.message.outputs` | `dify.message.run` | +| `dify.tool.inputs` | `dify.tool.execution` | +| `dify.tool.outputs` | `dify.tool.execution` | +| `dify.tool.parameters` | `dify.tool.execution` | +| `dify.tool.config` | `dify.tool.execution` | +| `dify.moderation.query` | `dify.moderation.check` | +| `dify.suggested_question.questions` | `dify.suggested_question.generation` | +| `dify.retrieval.query` | `dify.dataset.retrieval` | +| `dify.dataset.documents` | `dify.dataset.retrieval` | +| `dify.generate_name.inputs` | `dify.generate_name.execution` | +| `dify.generate_name.outputs` | `dify.generate_name.execution` | +| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` | +| `dify.prompt_generation.output` | `dify.prompt_generation.execution` | +| `dify.feedback.content` | `dify.feedback.created` | + +## Appendix + +### Operation Types + +- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify` + +### Node Types + +- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input` + +### Workflow Statuses + +- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused` + +### Payload Types + +- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback` + +### Null Value Behavior + +**Spans:** Attributes with `null` values are omitted. + +**Logs:** Attributes with `null` values appear as `null` in JSON. + +**Content-Gated:** Replaced with reference strings, not set to `null`. diff --git a/api/enterprise/telemetry/README.md b/api/enterprise/telemetry/README.md new file mode 100644 index 0000000000..e43c0b1ea2 --- /dev/null +++ b/api/enterprise/telemetry/README.md @@ -0,0 +1,121 @@ +# Dify Enterprise Telemetry + +This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb. + +## Overview + +Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage. + +- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes). +- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`. +- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking. 
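+A minimal, stdlib-only sketch of how this correlation works (illustrative only, not the exporter implementation — the derivation rules are the ones spelled out in the Correlation Model section below):
+
+```python
+import uuid
+
+
+def derive_trace_id(correlation_id: str) -> str:
+    # Full 128-bit UUID -> 32-char hex OTEL trace id.
+    return format(uuid.UUID(correlation_id).int, "032x")
+
+
+def derive_span_id(source_id: str) -> str:
+    # Lower 64 bits of the UUID -> 16-char hex OTEL span id.
+    return format(uuid.UUID(source_id).int & ((1 << 64) - 1), "016x")
+
+
+run_id = "550e8400-e29b-41d4-a716-446655440000"  # a workflow_run_id
+print(derive_trace_id(run_id))  # shared by the span and its companion log
+print(derive_span_id(run_id))   # root span id for the workflow run
+```
+
+Any log record that carries the same `trace_id`/`span_id` pair can be joined back to its span in the trace backend, which is how the rich companion logs stay queryable without bloating the spans themselves.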
+ +### Signal Architecture + +```mermaid +graph TD + A[Workflow Run] -->|Span| B(dify.workflow.run) + A -->|Log| C(dify.workflow.run detail) + B ---|trace_id| C + + D[Node Execution] -->|Span| E(dify.node.execution) + D -->|Log| F(dify.node.execution detail) + E ---|span_id| F + + G[Message/Tool/etc] -->|Log| H(dify.* event) + G -->|Metric| I(dify.* counter/histogram) +``` + +## Configuration + +The Enterprise OTEL exporter is configured via environment variables. + +| Variable | Description | Default | +|----------|-------------|---------| +| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` | +| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` | +| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - | +| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - | +| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` | +| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - | +| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `false` | +| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` | +| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` | + +## Correlation Model + +Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks. + +### ID Generation Rules + +- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))` +- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)` + +### Scenario A: Simple Workflow + +A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`). + +``` +trace_id = UUID(workflow_run_id) +├── [root span] dify.workflow.run (span_id = hash(workflow_run_id)) +│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1)) +│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2)) +│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3)) +``` + +### Scenario B: Nested Sub-Workflow + +A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id. + +``` +trace_id = UUID(outer_workflow_run_id) ← shared across both workflows +├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id)) +│ ├── dify.node.execution - "Start Node" +│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow) +│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id)) +│ │ ├── dify.node.execution - "Inner Start" +│ │ └── dify.node.execution - "Inner End" +│ └── dify.node.execution - "End Node" +``` + +**Key attributes for nested workflows:** + +- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id` +- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.app.id` = outer `app_id` + +### Scenario C: Draft Node Execution + +A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root. 
+ +``` +trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow +└── dify.node.execution.draft (span_id = hash(node_execution_id)) +``` + +**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace. + +## Content Gating + +When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector. + +**Reference String Format:** + +``` +ref:{id_type}={uuid} +``` + +**Examples:** + +``` +ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000 +ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001 +ref:message_id=770e8400-e29b-41d4-a716-446655440002 +``` + +To retrieve actual content when gating is enabled, query the Dify database using the provided UUID. + +## Reference + +For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md). diff --git a/api/graphon/graph_engine/entities/__init__.py b/api/enterprise/telemetry/__init__.py similarity index 100% rename from api/graphon/graph_engine/entities/__init__.py rename to api/enterprise/telemetry/__init__.py diff --git a/api/enterprise/telemetry/contracts.py b/api/enterprise/telemetry/contracts.py new file mode 100644 index 0000000000..91398cb8cb --- /dev/null +++ b/api/enterprise/telemetry/contracts.py @@ -0,0 +1,73 @@ +"""Telemetry gateway contracts and data structures. + +This module defines the envelope format for telemetry events and the routing +configuration that determines how each event type is processed. +""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class TelemetryCase(StrEnum): + """Enumeration of all known telemetry event cases.""" + + WORKFLOW_RUN = "workflow_run" + NODE_EXECUTION = "node_execution" + DRAFT_NODE_EXECUTION = "draft_node_execution" + MESSAGE_RUN = "message_run" + TOOL_EXECUTION = "tool_execution" + MODERATION_CHECK = "moderation_check" + SUGGESTED_QUESTION = "suggested_question" + DATASET_RETRIEVAL = "dataset_retrieval" + GENERATE_NAME = "generate_name" + PROMPT_GENERATION = "prompt_generation" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + FEEDBACK_CREATED = "feedback_created" + + +class SignalType(StrEnum): + """Signal routing type for telemetry cases.""" + + TRACE = "trace" + METRIC_LOG = "metric_log" + + +class CaseRoute(BaseModel): + """Routing configuration for a telemetry case. + + Attributes: + signal_type: The type of signal (trace or metric_log). + ce_eligible: Whether this case is eligible for community edition tracing. + """ + + signal_type: SignalType + ce_eligible: bool + + +class TelemetryEnvelope(BaseModel): + """Envelope for telemetry events. + + Attributes: + case: The telemetry case type. + tenant_id: The tenant identifier. + event_id: Unique event identifier for deduplication. + payload: The main event payload (inline for small payloads, + empty when offloaded to storage via ``payload_ref``). + metadata: Optional metadata dictionary. When the gateway + offloads a large payload to object storage, this contains + ``{"payload_ref": "<storage_key>"}``.
+ """ + + model_config = ConfigDict(extra="forbid", use_enum_values=False) + + case: TelemetryCase + tenant_id: str + event_id: str + payload: dict[str, Any] + metadata: dict[str, Any] | None = None diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py new file mode 100644 index 0000000000..5a8d0ee6f4 --- /dev/null +++ b/api/enterprise/telemetry/draft_trace.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from graphon.enums import WorkflowNodeExecutionMetadataKey + +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from models.workflow import WorkflowNodeExecutionModel + + +def enqueue_draft_node_execution_trace( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, + user_id: str, +) -> None: + node_data = _build_node_execution_data( + execution=execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=execution.tenant_id, + user_id=user_id, + app_id=execution.app_id, + ), + payload={"node_execution_data": node_data}, + ) + ) + + +def _build_node_execution_data( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, +) -> dict[str, Any]: + metadata = execution.execution_metadata_dict + node_outputs = outputs if outputs is not None else execution.outputs_dict + execution_id = workflow_execution_id or execution.workflow_run_id or execution.id + process_data = execution.process_data_dict or {} + + # Extract token breakdown from outputs.usage (set by LLM node) + usage: Mapping[str, Any] = {} + if isinstance(node_outputs, Mapping): + raw_usage = node_outputs.get("usage") + if isinstance(raw_usage, Mapping): + usage = raw_usage + + return { + "workflow_id": execution.workflow_id, + "workflow_execution_id": execution_id, + "tenant_id": execution.tenant_id, + "app_id": execution.app_id, + "node_execution_id": execution.id, + "node_id": execution.node_id, + "node_type": execution.node_type, + "title": execution.title, + "status": execution.status, + "error": execution.error, + "elapsed_time": execution.elapsed_time, + "index": execution.index, + "predecessor_node_id": execution.predecessor_node_id, + "created_at": execution.created_at, + "finished_at": execution.finished_at, + "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "model_provider": process_data.get("model_provider"), + "model_name": process_data.get("model_name"), + "prompt_tokens": usage.get("prompt_tokens"), + "completion_tokens": usage.get("completion_tokens"), + "tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": 
metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": execution.inputs_dict, + "node_outputs": node_outputs, + "process_data": execution.process_data_dict, + } diff --git a/api/enterprise/telemetry/enterprise_trace.py b/api/enterprise/telemetry/enterprise_trace.py new file mode 100644 index 0000000000..fc17d9d93e --- /dev/null +++ b/api/enterprise/telemetry/enterprise_trace.py @@ -0,0 +1,966 @@ +"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass. + +Invoked directly in the Celery task, not through OpsTraceManager dispatch. +Only requires a matching ``trace(trace_info)`` method signature. + +Signal strategy: +- **Traces (spans)**: workflow run, node execution, draft node execution only. +- **Metrics + structured logs**: all other event types. + +Token metric labels (unified structure): +All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the +same label set for consistent filtering and aggregation: +- tenant_id: Tenant identifier +- app_id: Application identifier +- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.) +- model_provider: LLM provider name (empty string if not applicable) +- model_name: LLM model name (empty string if not applicable) +- node_type: Workflow node type (empty string if not node_execution) + +This unified structure allows filtering by operation_type to separate: +- Workflow-level aggregates (operation_type=workflow) +- Individual node executions (operation_type=node_execution) +- Direct message calls (operation_type=message) +- Prompt generation operations (operation_type=rule_generate, code_generate, etc.) + +Without this, tokens are double-counted when querying totals (workflow totals include +node totals, since workflow.total_tokens is the sum of all node tokens). +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, cast + +from opentelemetry.util.types import AttributeValue + +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + OperationType, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryEvent, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, + TokenMetricLabels, +) +from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log + +logger = logging.getLogger(__name__) + + +class EnterpriseOtelTrace: + """Duck-typed enterprise trace handler. + + ``*_trace`` methods emit spans (workflow/node only) or structured logs + (all other events), plus metrics at 100 % accuracy. 
+ """ + + def __init__(self) -> None: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if exporter is None: + raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized") + self._exporter = exporter + + def trace(self, trace_info: BaseTraceInfo) -> None: + if isinstance(trace_info, WorkflowTraceInfo): + self._workflow_trace(trace_info) + elif isinstance(trace_info, MessageTraceInfo): + self._message_trace(trace_info) + elif isinstance(trace_info, ToolTraceInfo): + self._tool_trace(trace_info) + elif isinstance(trace_info, DraftNodeExecutionTrace): + self._draft_node_execution_trace(trace_info) + elif isinstance(trace_info, WorkflowNodeTraceInfo): + self._node_execution_trace(trace_info) + elif isinstance(trace_info, ModerationTraceInfo): + self._moderation_trace(trace_info) + elif isinstance(trace_info, SuggestedQuestionTraceInfo): + self._suggested_question_trace(trace_info) + elif isinstance(trace_info, DatasetRetrievalTraceInfo): + self._dataset_retrieval_trace(trace_info) + elif isinstance(trace_info, GenerateNameTraceInfo): + self._generate_name_trace(trace_info) + elif isinstance(trace_info, PromptGenerationTraceInfo): + self._prompt_generation_trace(trace_info) + else: + raise AssertionError("this statment should be unreachable") + + def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + metadata = self._metadata(trace_info) + tenant_id, app_id, user_id = self._context_ids(trace_info, metadata) + return { + "dify.trace_id": trace_info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "dify.message.id": trace_info.message_id, + } + + def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + return trace_info.metadata + + def _context_ids( + self, + trace_info: BaseTraceInfo, + metadata: dict[str, Any], + ) -> tuple[str | None, str | None, str | None]: + tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id") + app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id") + user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id") + return tenant_id, app_id, user_id + + def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]: + return dict(values) + + def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None: + if isinstance(value, str): + return value + if isinstance(value, dict): + return cast(dict[str, Any], value) + if isinstance(value, list): + items: list[object] = [] + for item in cast(list[object], value): + items.append(item) + return items + return None + + def _content_or_ref(self, value: Any, ref: str) -> Any: + if self._exporter.include_content: + return self._maybe_json(value) + return ref + + def _maybe_json(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, default=str) + except (TypeError, ValueError): + return str(value) + + # ------------------------------------------------------------------ + # SPAN-emitting handlers (workflow, node execution, draft node) + # ------------------------------------------------------------------ + + def _workflow_trace(self, info: WorkflowTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: 
identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.workflow.status": info.workflow_run_status, + "dify.workflow.error": info.error, + "dify.workflow.elapsed_time": info.workflow_run_elapsed_time, + "dify.invoke_from": metadata.get("triggered_from"), + "dify.conversation.id": info.conversation_id, + "dify.message.id": info.message_id, + "dify.invoked_by": info.invoked_by, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.user.id": user_id, + } + + trace_correlation_override, parent_span_id_source = info.resolved_parent_context + + parent_ctx = metadata.get("parent_trace_context") + if isinstance(parent_ctx, dict): + parent_ctx_dict = cast(dict[str, Any], parent_ctx) + span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id") + span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id") + span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id") + span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id") + + self._exporter.export_span( + EnterpriseTelemetrySpan.WORKFLOW_RUN, + span_attrs, + correlation_id=info.workflow_run_id, + span_id_source=info.workflow_run_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + parent_span_id_source=parent_span_id_source, + ) + + # -- Companion log: ALL attrs (span + detail) for full picture -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.workflow.version": info.workflow_run_version, + } + ) + + ref = f"ref:workflow_run_id={info.workflow_run_id}" + log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref) + log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref) + log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref) + + emit_telemetry_log( + event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.workflow_run_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + invoke_from = metadata.get("triggered_from", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="workflow", + status=info.workflow_run_status, + invoke_from=invoke_from, + ), + ) + # Prefer 
wall-clock timestamps over the elapsed_time field: elapsed_time defaults + # to 0 in the DB and can be stale if the Celery write races with the trace task. + # start_time = workflow_run.created_at, end_time = workflow_run.finished_at. + if info.start_time and info.end_time: + workflow_duration = (info.end_time - info.start_time).total_seconds() + elif info.workflow_run_elapsed_time: + workflow_duration = float(info.workflow_run_elapsed_time) + else: + workflow_duration = 0.0 + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.WORKFLOW_DURATION, + workflow_duration, + self._labels( + **labels, + status=info.workflow_run_status, + ), + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="workflow", + ), + ) + + def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None: + self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node") + + def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None: + self._emit_node_execution_trace( + info, + EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION, + "draft_node", + correlation_id_override=info.node_execution_id, + trace_correlation_override_param=info.workflow_run_id, + ) + + def _emit_node_execution_trace( + self, + info: WorkflowNodeTraceInfo, + span_name: EnterpriseTelemetrySpan, + request_type: str, + correlation_id_override: str | None = None, + trace_correlation_override_param: str | None = None, + ) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.message.id": info.message_id, + "dify.conversation.id": metadata.get("conversation_id"), + "dify.node.execution_id": info.node_execution_id, + "dify.node.id": info.node_id, + "dify.node.type": info.node_type, + "dify.node.title": info.title, + "dify.node.status": info.status, + "dify.node.error": info.error, + "dify.node.elapsed_time": info.elapsed_time, + "dify.node.index": info.index, + "dify.node.predecessor_node_id": info.predecessor_node_id, + "dify.node.iteration_id": info.iteration_id, + "dify.node.loop_id": info.loop_id, + "dify.node.parallel_id": info.parallel_id, + "dify.node.invoked_by": info.invoked_by, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.request.model": info.model_name, + "gen_ai.provider.name": info.model_provider, + "gen_ai.user.id": user_id, + } + + resolved_override, _ = info.resolved_parent_context + trace_correlation_override = trace_correlation_override_param or resolved_override + + effective_correlation_id = correlation_id_override or info.workflow_run_id + self._exporter.export_span( + span_name, + span_attrs, + correlation_id=effective_correlation_id, + span_id_source=info.node_execution_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + ) + + # -- Companion log: ALL attrs (span + detail) -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.invoke_from": 
metadata.get("invoke_from"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.node.total_price": info.total_price, + "dify.node.currency": info.currency, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.tool.name": info.tool_name, + "dify.node.iteration_index": info.iteration_index, + "dify.node.loop_index": info.loop_index, + "dify.plugin.name": metadata.get("plugin_name"), + "dify.credential.name": metadata.get("credential_name"), + "dify.credential.id": metadata.get("credential_id"), + "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")), + "dify.dataset.names": self._maybe_json(metadata.get("dataset_names")), + } + ) + + ref = f"ref:node_execution_id={info.node_execution_id}" + log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref) + log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref) + log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref) + + emit_telemetry_log( + event_name=span_name.value, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + node_type=info.node_type, + model_provider=info.model_provider or "", + ) + if info.total_tokens: + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.NODE_EXECUTION, + model_provider=info.model_provider or "", + model_name=info.model_name or "", + node_type=info.node_type, + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels + ) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type=request_type, + status=info.status, + model_name=info.model_name or "", + ), + ) + duration_labels = dict(labels) + duration_labels["model_name"] = info.model_name or "" + plugin_name = metadata.get("plugin_name") + if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}: + duration_labels["plugin_name"] = plugin_name + self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type=request_type, + model_name=info.model_name or "", + ), + ) + + # ------------------------------------------------------------------ + # METRIC-ONLY handlers (structured log + counters/histograms) + # ------------------------------------------------------------------ + + def _message_trace(self, info: MessageTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.invoke_from": metadata.get("from_source"), + "dify.conversation.id": metadata.get("conversation_id"), + "dify.conversation.mode": info.conversation_mode, + 
"gen_ai.provider.name": metadata.get("ls_provider"), + "gen_ai.request.model": metadata.get("ls_model_name"), + "gen_ai.usage.input_tokens": info.message_tokens, + "gen_ai.usage.output_tokens": info.answer_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.message.status": metadata.get("status"), + "dify.message.error": info.error, + "dify.message.from_source": metadata.get("from_source"), + "dify.message.from_end_user_id": metadata.get("from_end_user_id"), + "dify.message.from_account_id": metadata.get("from_account_id"), + "dify.streaming": info.is_streaming_request, + "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token, + "dify.message.streaming_duration": info.llm_streaming_time_to_generate, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + + if info.start_time and info.end_time: + attrs["dify.message.duration"] = (info.end_time - info.start_time).total_seconds() + + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MESSAGE_RUN, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.MESSAGE, + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.message_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels) + if info.answer_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels) + invoke_from = metadata.get("from_source", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="message", + status=metadata.get("status", ""), + invoke_from=invoke_from, + ), + ) + + if info.start_time and info.end_time: + duration = (info.end_time - info.start_time).total_seconds() + self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels) + + if info.gen_ai_server_time_to_first_token is not None: + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="message", + ), + ) + + def _tool_trace(self, info: ToolTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.tool.name": info.tool_name, + "dify.tool.duration": 
float(info.time_cost), + "dify.tool.status": "failed" if info.error else "succeeded", + "dify.tool.error": info.error, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref) + attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref) + attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref) + attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + tool_name=info.tool_name, + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + + def _moderation_trace(self, info: ModerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.moderation.flagged": info.flagged, + "dify.moderation.action": info.action, + "dify.moderation.preset_response": info.preset_response, + "dify.moderation.type": metadata.get("moderation_type", "input"), + "dify.moderation.categories": self._maybe_json(metadata.get("moderation_categories", [])), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.moderation.query"] = self._content_or_ref( + info.query, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MODERATION_CHECK, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="moderation", + ), + ) + + def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error = info.error or (info.metadata.get("error") if info.metadata else None) + status = "failed" if error else (info.status or "succeeded") + attrs.update( + { + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.suggested_question.status": status, + "dify.suggested_question.error": error, + "dify.suggested_question.duration": duration, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_id, + "dify.suggested_question.count": 
len(info.suggested_question), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.suggested_question.questions"] = self._content_or_ref( + info.suggested_question, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="suggested_question", + model_provider=info.model_provider or "", + model_name=info.model_id or "", + ), + ) + + def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.retrieval.error"] = info.error + attrs["dify.retrieval.status"] = "failed" if info.error else "succeeded" + if info.start_time and info.end_time: + attrs["dify.retrieval.duration"] = (info.end_time - info.start_time).total_seconds() + attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id") + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + docs: list[dict[str, Any]] = [] + documents_any: Any = info.documents + documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else [] + for entry in documents_list: + if isinstance(entry, dict): + entry_dict: dict[str, Any] = cast(dict[str, Any], entry) + docs.append(entry_dict) + dataset_ids: list[str] = [] + dataset_names: list[str] = [] + structured_docs: list[dict[str, Any]] = [] + for doc in docs: + meta_raw = doc.get("metadata") + meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {} + did = meta.get("dataset_id") + dname = meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + structured_docs.append( + { + "dataset_id": did, + "document_id": meta.get("document_id"), + "segment_id": meta.get("segment_id"), + "score": meta.get("score"), + } + ) + + attrs["dify.dataset.id"] = self._maybe_json(dataset_ids) + attrs["dify.dataset.name"] = self._maybe_json(dataset_names) + attrs["dify.retrieval.document_count"] = len(docs) + + embedding_models_raw: Any = metadata.get("embedding_models") + embedding_models: dict[str, Any] = ( + cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {} + ) + if embedding_models: + providers: list[str] = [] + models: list[str] = [] + for ds_info in embedding_models.values(): + if isinstance(ds_info, dict): + ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info) + p = ds_info_dict.get("embedding_model_provider", "") + m = ds_info_dict.get("embedding_model", "") + if p and p not in providers: + providers.append(p) + if m and m not in models: + models.append(m) + attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers) + attrs["dify.dataset.embedding_models"] = self._maybe_json(models) + + # Add rerank model to logs + rerank_provider = 
metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + if rerank_provider or rerank_model: + attrs["dify.retrieval.rerank_provider"] = rerank_provider + attrs["dify.retrieval.rerank_model"] = rerank_model + + ref = f"ref:message_id={info.message_id}" + retrieval_inputs = self._safe_payload_value(info.inputs) + attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref) + attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None), + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="dataset_retrieval", + ), + ) + + for did in dataset_ids: + # Get embedding model for this specific dataset + ds_embedding_info = embedding_models.get(did, {}) + embedding_provider = ds_embedding_info.get("embedding_model_provider", "") + embedding_model = ds_embedding_info.get("embedding_model", "") + + # Get rerank model (same for all datasets in this retrieval) + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + + self._exporter.increment_counter( + EnterpriseTelemetryCounter.DATASET_RETRIEVALS, + 1, + self._labels( + **labels, + dataset_id=did, + embedding_model_provider=embedding_provider, + embedding_model=embedding_model, + rerank_model_provider=rerank_provider, + rerank_model=rerank_model, + ), + ) + + def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.conversation.id"] = info.conversation_id + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error: str | None = metadata.get("error") if metadata else None + status = "failed" if error else "succeeded" + attrs["dify.generate_name.duration"] = duration + attrs["dify.generate_name.status"] = status + attrs["dify.generate_name.error"] = error + + ref = f"ref:conversation_id={info.conversation_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="generate_name", + ), + ) + + def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, 
user_id = self._context_ids(info, metadata) + attrs = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "gen_ai.user.id": user_id, + "dify.app_id": app_id or "", + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.prompt_generation.operation_type": info.operation_type, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.prompt_generation.duration": info.latency, + "dify.prompt_generation.status": "failed" if info.error else "succeeded", + "dify.prompt_generation.error": info.error, + } + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + if info.total_price is not None: + attrs["dify.prompt_generation.total_price"] = info.total_price + attrs["dify.prompt_generation.currency"] = info.currency + + ref = f"ref:trace_id={info.trace_id}" + outputs = self._safe_payload_value(info.outputs) + attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref) + attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + node_type="", + ).to_dict() + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + ) + + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + + prompt_status = "failed" if info.error else "succeeded" + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="prompt_generation", + status=prompt_status, + ), + ) + + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION, + info.latency, + labels, + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="prompt_generation", + ), + ) diff --git a/api/enterprise/telemetry/entities/__init__.py b/api/enterprise/telemetry/entities/__init__.py new file mode 100644 index 0000000000..4a9bd3dbf8 --- /dev/null +++ b/api/enterprise/telemetry/entities/__init__.py @@ -0,0 +1,121 @@ +from enum import StrEnum +from typing import cast + +from opentelemetry.util.types import AttributeValue +from pydantic import BaseModel, ConfigDict + + +class EnterpriseTelemetrySpan(StrEnum): + WORKFLOW_RUN = "dify.workflow.run" + NODE_EXECUTION = "dify.node.execution" + DRAFT_NODE_EXECUTION = "dify.node.execution.draft" + + +class EnterpriseTelemetryEvent(StrEnum): + """Event names for enterprise 
telemetry logs.""" + + APP_CREATED = "dify.app.created" + APP_UPDATED = "dify.app.updated" + APP_DELETED = "dify.app.deleted" + FEEDBACK_CREATED = "dify.feedback.created" + WORKFLOW_RUN = "dify.workflow.run" + MESSAGE_RUN = "dify.message.run" + TOOL_EXECUTION = "dify.tool.execution" + MODERATION_CHECK = "dify.moderation.check" + SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation" + DATASET_RETRIEVAL = "dify.dataset.retrieval" + GENERATE_NAME_EXECUTION = "dify.generate_name.execution" + PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution" + REHYDRATION_FAILED = "dify.telemetry.rehydration_failed" + + +class EnterpriseTelemetryCounter(StrEnum): + TOKENS = "tokens" + INPUT_TOKENS = "input_tokens" + OUTPUT_TOKENS = "output_tokens" + REQUESTS = "requests" + ERRORS = "errors" + FEEDBACK = "feedback" + DATASET_RETRIEVALS = "dataset_retrievals" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + + +class EnterpriseTelemetryHistogram(StrEnum): + WORKFLOW_DURATION = "workflow_duration" + NODE_DURATION = "node_duration" + MESSAGE_DURATION = "message_duration" + MESSAGE_TTFT = "message_ttft" + TOOL_DURATION = "tool_duration" + PROMPT_GENERATION_DURATION = "prompt_generation_duration" + + +class TokenMetricLabels(BaseModel): + """Unified label structure for all dify.token.* metrics. + + All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST + use this exact label set to ensure consistent filtering and aggregation across + different operation types. + + Attributes: + tenant_id: Tenant identifier. + app_id: Application identifier. + operation_type: Source of token usage (workflow | node_execution | message | + rule_generate | code_generate | structured_output | instruction_modify). + model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level). + model_name: LLM model name. Empty string if not applicable (e.g., workflow-level). + node_type: Workflow node type. Empty string unless operation_type=node_execution. + + Usage: + labels = TokenMetricLabels( + tenant_id="tenant-123", + app_id="app-456", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, + 100, + labels.to_dict() + ) + + Design rationale: + Without this unified structure, tokens get double-counted when querying totals + because workflow.total_tokens is already the sum of all node tokens. The + operation_type label allows filtering to separate workflow-level aggregates from + node-level detail, while keeping the same label cardinality for consistent queries. 
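+
+ For example, a backend query like ``sum by (app_id) (dify.tokens.total{operation_type="workflow"})``
+ (a PromQL-style sketch; adjust to your metrics backend) returns per-app totals
+ without double-counting the node-level rows.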
+ """ + + tenant_id: str + app_id: str + operation_type: str + model_provider: str + model_name: str + node_type: str + + model_config = ConfigDict(extra="forbid", frozen=True) + + def to_dict(self) -> dict[str, AttributeValue]: + return cast( + dict[str, AttributeValue], + { + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "operation_type": self.operation_type, + "model_provider": self.model_provider, + "model_name": self.model_name, + "node_type": self.node_type, + }, + ) + + +__all__ = [ + "EnterpriseTelemetryCounter", + "EnterpriseTelemetryEvent", + "EnterpriseTelemetryHistogram", + "EnterpriseTelemetrySpan", + "TokenMetricLabels", +] diff --git a/api/enterprise/telemetry/event_handlers.py b/api/enterprise/telemetry/event_handlers.py new file mode 100644 index 0000000000..d8b4208c69 --- /dev/null +++ b/api/enterprise/telemetry/event_handlers.py @@ -0,0 +1,72 @@ +"""Blinker signal handlers for enterprise telemetry. + +Registered at import time via ``@signal.connect`` decorators. +Import must happen during ``ext_enterprise_telemetry.init_app()`` to +ensure handlers fire. Each handler delegates to ``core.telemetry.gateway`` +which handles routing, EE-gating, and dispatch. + +All handlers are best-effort: exceptions are caught and logged so that +telemetry failures never break user-facing operations. +""" + +from __future__ import annotations + +import logging + +from events.app_event import app_was_created, app_was_deleted, app_was_updated + +logger = logging.getLogger(__name__) + +__all__ = [ + "_handle_app_created", + "_handle_app_deleted", + "_handle_app_updated", +] + + +@app_was_created.connect +def _handle_app_created(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={ + "app_id": getattr(sender, "id", None), + "mode": getattr(sender, "mode", None), + }, + ) + except Exception: + logger.warning("Failed to emit app_created telemetry", exc_info=True) + + +@app_was_updated.connect +def _handle_app_updated(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_UPDATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_updated telemetry", exc_info=True) + + +@app_was_deleted.connect +def _handle_app_deleted(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_DELETED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_deleted telemetry", exc_info=True) diff --git a/api/enterprise/telemetry/exporter.py b/api/enterprise/telemetry/exporter.py new file mode 100644 index 0000000000..b2f860764f --- /dev/null +++ b/api/enterprise/telemetry/exporter.py @@ -0,0 +1,283 @@ +"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation. 
+
+Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
+independent of ext_otel.py infrastructure).
+
+Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
+Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
+"""
+
+import logging
+import socket
+import uuid
+from datetime import UTC, datetime
+from typing import Any, cast
+
+from opentelemetry import trace
+from opentelemetry.baggage import get_all
+from opentelemetry.baggage.propagation import W3CBaggagePropagator
+from opentelemetry.context import Context
+from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
+from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
+from opentelemetry.sdk.metrics import MeterProvider
+from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
+from opentelemetry.semconv.resource import ResourceAttributes
+from opentelemetry.trace import SpanContext, TraceFlags
+from opentelemetry.util.types import Attributes, AttributeValue
+
+from configs import dify_config
+from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
+from enterprise.telemetry.id_generator import (
+    CorrelationIdGenerator,
+    compute_deterministic_span_id,
+    set_correlation_id,
+    set_span_id_source,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def is_enterprise_telemetry_enabled() -> bool:
+    return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
+
+
+def _parse_otlp_headers(raw: str) -> dict[str, str]:
+    ctx = W3CBaggagePropagator().extract({"baggage": raw})
+    return {k: v for k, v in get_all(ctx).items() if isinstance(v, str)}
+
+
+def _datetime_to_ns(dt: datetime) -> int:
+    """Convert a datetime to nanoseconds since epoch (OTEL convention)."""
+    # Ensure we always interpret naive datetimes as UTC instead of local time.
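+    # For example (under this convention), a naive datetime(2024, 1, 1) is read as
+    # 2024-01-01T00:00:00Z and maps to 1_704_067_200 * 10**9 ns, regardless of host timezone.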
+ if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + else: + dt = dt.astimezone(UTC) + return int(dt.timestamp() * 1_000_000_000) + + +class _ExporterFactory: + def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool): + self._protocol = protocol + self._endpoint = endpoint + self._headers = headers + self._grpc_headers = tuple(headers.items()) if headers else None + self._http_headers = headers or None + self._insecure = insecure + + def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter: + if self._protocol == "grpc": + return GRPCSpanExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else "" + return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers) + + def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter: + if self._protocol == "grpc": + return GRPCMetricExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else "" + return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers) + + +class EnterpriseExporter: + """Shared OTEL exporter for all enterprise telemetry. + + ``export_span`` creates spans with optional real timestamps, deterministic + span/trace IDs, and cross-workflow parent linking. + ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy. + """ + + def __init__(self, config: object) -> None: + endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") + headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") + protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower() + service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify") + sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0) + self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True) + api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "") + + # Auto-detect TLS: https:// uses secure, everything else is insecure + insecure = not endpoint.startswith("https://") + + resource = Resource( + attributes={ + ResourceAttributes.SERVICE_NAME: service_name, + ResourceAttributes.HOST_NAME: socket.gethostname(), + } + ) + sampler = ParentBasedTraceIdRatio(sampling_rate) + id_generator = CorrelationIdGenerator() + self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator) + + headers = _parse_otlp_headers(headers_raw) + if api_key: + if "authorization" in headers: + logger.warning( + "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains " + "'authorization'; the API key will take precedence." 
+ ) + headers["authorization"] = f"Bearer {api_key}" + factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure) + + trace_exporter = factory.create_trace_exporter() + self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + self._tracer = self._tracer_provider.get_tracer("dify.enterprise") + + metric_exporter = factory.create_metric_exporter() + self._meter_provider = MeterProvider( + resource=resource, + metric_readers=[PeriodicExportingMetricReader(metric_exporter)], + ) + meter = self._meter_provider.get_meter("dify.enterprise") + self._counters = { + EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"), + EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"), + EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"), + EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"), + EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"), + EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"), + EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter( + "dify.dataset.retrievals.total", unit="{retrieval}" + ), + EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"), + } + self._histograms = { + EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"), + EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram( + "dify.message.time_to_first_token", unit="s" + ), + EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"), + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram( + "dify.prompt_generation.duration", unit="s" + ), + } + + def export_span( + self, + name: str, + attributes: dict[str, Any], + correlation_id: str | None = None, + span_id_source: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + trace_correlation_override: str | None = None, + parent_span_id_source: str | None = None, + ) -> None: + """Export an OTEL span with optional deterministic IDs and real timestamps. + + Args: + name: Span operation name. + attributes: Span attributes dict. + correlation_id: Source for trace_id derivation (groups spans in one trace). + span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id). + start_time: Real span start time. When None, uses current time. + end_time: Real span end time. When None, span ends immediately. + trace_correlation_override: Override trace_id source (for cross-workflow linking). + When set, trace_id is derived from this instead of ``correlation_id``. + parent_span_id_source: Override parent span_id source (for cross-workflow linking). + When set, parent span_id is derived from this value. When None and + ``correlation_id`` is set, parent is the workflow root span. 
+ """ + effective_trace_correlation = trace_correlation_override or correlation_id + set_correlation_id(effective_trace_correlation) + set_span_id_source(span_id_source) + + try: + parent_context: Context | None = None + # A span is the "root" of its correlation group when span_id_source == correlation_id + # (i.e. a workflow root span). All other spans are children. + if parent_span_id_source: + # Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow) + parent_span_id = compute_deterministic_span_id(parent_span_id_source) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0 + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for cross-workflow link: %s, span=%s", + effective_trace_correlation, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + elif correlation_id and correlation_id != span_id_source: + # Child span: parent is the correlation-group root (workflow root span) + parent_span_id = compute_deterministic_span_id(correlation_id) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id)) + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for child span link: %s, span=%s", + effective_trace_correlation or correlation_id, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + + span_start_time = _datetime_to_ns(start_time) if start_time is not None else None + span_end_on_exit = end_time is None + + with self._tracer.start_as_current_span( + name, + context=parent_context, + start_time=span_start_time, + end_on_exit=span_end_on_exit, + ) as span: + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + if end_time is not None: + span.end(end_time=_datetime_to_ns(end_time)) + except Exception: + logger.exception("Failed to export span %s", name) + finally: + set_correlation_id(None) + set_span_id_source(None) + + def increment_counter( + self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue] + ) -> None: + counter = self._counters.get(name) + if counter: + counter.add(value, cast(Attributes, labels)) + + def record_histogram( + self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue] + ) -> None: + histogram = self._histograms.get(name) + if histogram: + histogram.record(value, cast(Attributes, labels)) + + def shutdown(self) -> None: + self._tracer_provider.shutdown() + self._meter_provider.shutdown() diff --git a/api/enterprise/telemetry/id_generator.py b/api/enterprise/telemetry/id_generator.py new file mode 100644 index 0000000000..f3e5d6d0d6 --- /dev/null +++ b/api/enterprise/telemetry/id_generator.py @@ -0,0 +1,75 @@ +"""Custom OTEL ID Generator for correlation-based trace/span ID derivation. + +Uses contextvars for thread-safe correlation_id -> trace_id mapping. 
+When a span_id_source is set, the span_id is derived deterministically +from that value, enabling any span to reference another as parent +without depending on span creation order. +""" + +import random +import uuid +from contextvars import ContextVar + +from opentelemetry.sdk.trace.id_generator import IdGenerator + +_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None) +_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None) + + +def set_correlation_id(correlation_id: str | None) -> None: + _correlation_id_context.set(correlation_id) + + +def get_correlation_id() -> str | None: + return _correlation_id_context.get() + + +def set_span_id_source(source_id: str | None) -> None: + """Set the source for deterministic span_id generation. + + When set, ``generate_span_id()`` derives the span_id from this value + (lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow + root spans or ``node_execution_id`` for node spans. + """ + _span_id_source_context.set(source_id) + + +def compute_deterministic_span_id(source_id: str) -> int: + """Derive a deterministic span_id from any UUID string. + + Uses the lower 64 bits of the UUID, guaranteeing non-zero output + (OTEL requires span_id != 0). + """ + span_id = uuid.UUID(source_id).int & ((1 << 64) - 1) + return span_id if span_id != 0 else 1 + + +class CorrelationIdGenerator(IdGenerator): + """ID generator that derives trace_id and optionally span_id from context. + + - trace_id: always derived from correlation_id (groups all spans in one trace) + - span_id: derived from span_id_source when set (enables deterministic + parent-child linking), otherwise random + """ + + def generate_trace_id(self) -> int: + correlation_id = _correlation_id_context.get() + if correlation_id: + try: + return uuid.UUID(correlation_id).int + except (ValueError, AttributeError): + pass + return random.getrandbits(128) + + def generate_span_id(self) -> int: + source = _span_id_source_context.get() + if source: + try: + return compute_deterministic_span_id(source) + except (ValueError, AttributeError): + pass + + span_id = random.getrandbits(64) + while span_id == 0: + span_id = random.getrandbits(64) + return span_id diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py new file mode 100644 index 0000000000..ffd9a7e2b5 --- /dev/null +++ b/api/enterprise/telemetry/metric_handler.py @@ -0,0 +1,421 @@ +"""Enterprise metric/log event handler. + +This module processes metric and log telemetry events after they've been +dequeued from the enterprise_telemetry Celery queue. It handles case routing, +idempotency checking, and payload rehydration. +""" + +from __future__ import annotations + +import json +import logging +from datetime import UTC, datetime +from typing import Any + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage + +logger = logging.getLogger(__name__) + + +class EnterpriseMetricHandler: + """Handler for enterprise metric and log telemetry events. + + Processes envelopes from the enterprise_telemetry queue, routing each + case to the appropriate handler method. Implements idempotency checking + and payload rehydration with fallback. + """ + + def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None: + """Increment a diagnostic counter for operational monitoring. 
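+
+ Note: this currently records the count via a debug log only; the exporter
+ lookup guards against emitting anything when telemetry is disabled.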
+ + Args: + counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total'). + labels: Optional labels for the counter. + """ + try: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + return + + full_counter_name = f"enterprise_telemetry.handler.{counter_name}" + logger.debug( + "Diagnostic counter: %s, labels=%s", + full_counter_name, + labels or {}, + ) + except Exception: + logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True) + + def handle(self, envelope: TelemetryEnvelope) -> None: + """Main entry point for processing telemetry envelopes. + + Args: + envelope: The telemetry envelope to process. + """ + # Check for duplicate events + if self._is_duplicate(envelope): + logger.debug( + "Skipping duplicate event: tenant_id=%s, event_id=%s", + envelope.tenant_id, + envelope.event_id, + ) + self._increment_diagnostic_counter("deduped_total") + return + + # Route to appropriate handler based on case + case = envelope.case + if case == TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + elif case == TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + elif case == TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + elif case == TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + elif case == TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + elif case == TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + elif case == TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + elif case == TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + elif case == TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + elif case == TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + elif case == TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + else: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) + + def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: + """Check if this event has already been processed. + + Uses Redis with TTL for deduplication. Returns True if duplicate, + False if first time seeing this event. + + Args: + envelope: The telemetry envelope to check. + + Returns: + True if this event_id has been seen before, False otherwise. 
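+
+ Example key: ``telemetry:dedup:<tenant_id>:<event_id>``, set with NX and a 1h TTL.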
+ """ + dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}" + + try: + # Atomic set-if-not-exists with 1h TTL + # Returns True if key was set (first time), None if already exists (duplicate) + was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600) + return was_set is None + except Exception: + # Fail open: if Redis is unavailable, process the event + # (prefer occasional duplicate over lost data) + logger.warning( + "Redis unavailable for deduplication check, processing event anyway: %s", + envelope.event_id, + exc_info=True, + ) + return False + + def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]: + """Rehydrate payload from storage reference or inline data. + + If the envelope payload is empty and metadata contains a + ``payload_ref``, the full payload is loaded from object storage + (where the gateway wrote it as JSON). When both the inline + payload and storage resolution fail, a degraded-event marker + is emitted so the gap is observable. + + Args: + envelope: The telemetry envelope containing payload data. + + Returns: + The rehydrated payload dictionary, or ``{}`` on total failure. + """ + payload = envelope.payload + + # Resolve from object storage when the gateway offloaded a large payload. + if not payload and envelope.metadata: + payload_ref = envelope.metadata.get("payload_ref") + if payload_ref: + try: + payload_bytes = storage.load(payload_ref) + payload = json.loads(payload_bytes.decode("utf-8")) + logger.debug("Loaded payload from storage: key=%s", payload_ref) + except Exception: + logger.warning( + "Failed to load payload from storage: key=%s, event_id=%s", + payload_ref, + envelope.event_id, + exc_info=True, + ) + + if not payload: + # Storage resolution failed or no data available — emit degraded event. 
+ logger.error( + "Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s", + envelope.event_id, + envelope.tenant_id, + envelope.case, + ) + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED, + attributes={ + "tenant_id": envelope.tenant_id, + "dify.telemetry.error": f"Payload rehydration failed for event_id={envelope.event_id}", + "dify.telemetry.payload_type": envelope.case, + "dify.telemetry.correlation_id": envelope.event_id, + }, + tenant_id=envelope.tenant_id, + ) + self._increment_diagnostic_counter("rehydration_failed_total") + return {} + + return payload + + # Stub methods for each metric/log case + # These will be implemented in later tasks with actual emission logic + + def _on_app_created(self, envelope: TelemetryEnvelope) -> None: + """Handle app created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.mode": payload.get("mode"), + "dify.app.created_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_CREATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "mode": str(payload.get("mode", "")), + }, + ) + + def _on_app_updated(self, envelope: TelemetryEnvelope) -> None: + """Handle app updated event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.updated_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_UPDATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_UPDATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None: + """Handle app deleted event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_DELETED: 
event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.deleted_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_DELETED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_DELETED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None: + """Handle feedback created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + include_content = exporter.include_content + attrs: dict = { + "dify.message.id": payload.get("message_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app_id": payload.get("app_id"), + "dify.conversation.id": payload.get("conversation_id"), + "gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"), + "dify.feedback.rating": payload.get("rating"), + "dify.feedback.from_source": payload.get("from_source"), + "dify.feedback.created_at": datetime.now(UTC).isoformat(), + } + if include_content: + attrs["dify.feedback.content"] = payload.get("content") + + user_id = payload.get("from_end_user_id") or payload.get("from_account_id") + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + user_id=str(user_id or ""), + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.FEEDBACK, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "rating": str(payload.get("rating", "")), + }, + ) + + def _on_message_run(self, envelope: TelemetryEnvelope) -> None: + """Handle message run event. + + Intentionally a no-op: metrics and structured logs for message runs are + emitted directly by EnterpriseOtelTrace._message_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id) + + def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None: + """Handle tool execution event. + + Intentionally a no-op: metrics and structured logs for tool executions + are emitted directly by EnterpriseOtelTrace._tool_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id) + + def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None: + """Handle moderation check event. + + Intentionally a no-op: metrics and structured logs for moderation checks + are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id) + + def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None: + """Handle suggested question event. 
+ + Intentionally a no-op: metrics and structured logs for suggested questions + are emitted directly by EnterpriseOtelTrace._suggested_question_trace at + trace time, not through the metric handler queue path. + """ + logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id) + + def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None: + """Handle dataset retrieval event. + + Intentionally a no-op: metrics and structured logs for dataset retrievals + are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at + trace time, not through the metric handler queue path. + """ + logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id) + + def _on_generate_name(self, envelope: TelemetryEnvelope) -> None: + """Handle generate name event. + + Intentionally a no-op: metrics and structured logs for generate name + operations are emitted directly by EnterpriseOtelTrace._generate_name_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id) + + def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None: + """Handle prompt generation event. + + Intentionally a no-op: metrics and structured logs for prompt generation + operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id) diff --git a/api/enterprise/telemetry/telemetry_log.py b/api/enterprise/telemetry/telemetry_log.py new file mode 100644 index 0000000000..8cce4a9fcd --- /dev/null +++ b/api/enterprise/telemetry/telemetry_log.py @@ -0,0 +1,122 @@ +"""Structured-log emitter for enterprise telemetry events. + +Emits structured JSON log lines correlated with OTEL traces via trace_id. +Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic. +""" + +from __future__ import annotations + +import logging +import uuid +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + +logger = logging.getLogger("dify.telemetry") + + +@lru_cache(maxsize=4096) +def compute_trace_id_hex(uuid_str: str | None) -> str: + """Convert a business UUID string to a 32-hex OTEL-compatible trace_id. + + Returns empty string when *uuid_str* is ``None`` or invalid. + """ + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + return f"{uuid.UUID(normalized).int:032x}" + except (ValueError, AttributeError): + return "" + + +@lru_cache(maxsize=4096) +def compute_span_id_hex(uuid_str: str | None) -> str: + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + return f"{compute_deterministic_span_id(normalized):016x}" + except (ValueError, AttributeError): + return "" + + +def emit_telemetry_log( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + signal: str = "metric_only", + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + """Emit a structured log line for a telemetry event. 
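+
+ When ``span_id_source`` is given, it is hashed to a 16-hex span_id using the
+ same derivation as span export, so a log line can point at its exact span.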
+ + Parameters + ---------- + event_name: + Canonical event name, e.g. ``"dify.workflow.run"``. + attributes: + All event-specific attributes (already built by the caller). + signal: + ``"metric_only"`` for events with no span, ``"span_detail"`` + for detail logs accompanying a slim span. + trace_id_source: + A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex + trace_id for cross-signal correlation. + tenant_id: + Tenant identifier (for the ``IdentityContextFilter``). + user_id: + User identifier (for the ``IdentityContextFilter``). + """ + if not logger.isEnabledFor(logging.INFO): + return + attrs = { + "dify.event.name": event_name, + "dify.event.signal": signal, + **attributes, + } + + extra: dict[str, Any] = {"attributes": attrs} + + trace_id_hex = compute_trace_id_hex(trace_id_source) + if trace_id_hex: + extra["trace_id"] = trace_id_hex + span_id_hex = compute_span_id_hex(span_id_source) + if span_id_hex: + extra["span_id"] = span_id_hex + if tenant_id: + extra["tenant_id"] = tenant_id + if user_id: + extra["user_id"] = user_id + + logger.info("telemetry.%s", signal, extra=extra) + + +def emit_metric_only_event( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + emit_telemetry_log( + event_name=event_name, + attributes=attributes, + signal="metric_only", + trace_id_source=trace_id_source, + span_id_source=span_id_source, + tenant_id=tenant_id, + user_id=user_id, + ) diff --git a/api/events/app_event.py b/api/events/app_event.py index f2ce71bbbb..2fba0028f9 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -11,3 +11,9 @@ app_published_workflow_was_updated = signal("app-published-workflow-was-updated" # sender: app, kwargs: synced_draft_workflow app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") + +# sender: app +app_was_updated = signal("app-was-updated") + +# sender: app +app_was_deleted = signal("app-was-deleted") diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 5e7caf8cbe..84be592b1a 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,6 @@ from events.app_event import app_was_created from extensions.ext_database import db +from models.enums import CustomizeTokenStrategy from models.model import Site @@ -16,7 +17,7 @@ def handle(sender, **kwargs): icon=app.icon, icon_background=app.icon_background, default_language=account.interface_language, - customize_token_strategy="not_allow", + customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW, code=Site.generate_code(16), created_by=app.created_by, updated_by=app.updated_by, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index ba9758175f..7bd8e88231 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,11 +1,12 @@ import logging +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity + from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import 
ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_draft_workflow_was_synced -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 6769b94cde..86b5b2bbf0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,11 +1,11 @@ from typing import cast +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db -from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 367a4c1ede..4eed34436a 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery: "schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL), } + if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED: + imports.append("tasks.enterprise_telemetry_task") celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_enterprise_telemetry.py b/api/extensions/ext_enterprise_telemetry.py new file mode 100644 index 0000000000..b3cfa01aee --- /dev/null +++ b/api/extensions/ext_enterprise_telemetry.py @@ -0,0 +1,50 @@ +"""Flask extension for enterprise telemetry lifecycle management. + +Initializes the EnterpriseExporter singleton during ``create_app()`` +(single-threaded), registers blinker event handlers, and hooks atexit +for graceful shutdown. + +Skipped entirely when either ``ENTERPRISE_ENABLED`` or ``ENTERPRISE_TELEMETRY_ENABLED`` +is false (``is_enabled()`` gate). 
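+
+Callers should treat the exporter as optional: fetch it via
+``get_enterprise_exporter()`` and check for ``None``, since it is only
+populated when the feature is enabled.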
+""" + +from __future__ import annotations + +import atexit +import logging +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from dify_app import DifyApp + from enterprise.telemetry.exporter import EnterpriseExporter + +logger = logging.getLogger(__name__) + +_exporter: EnterpriseExporter | None = None + + +def is_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def init_app(app: DifyApp) -> None: + global _exporter + + if not is_enabled(): + return + + from enterprise.telemetry.exporter import EnterpriseExporter + + _exporter = EnterpriseExporter(dify_config) + atexit.register(_exporter.shutdown) + + # Import to trigger @signal.connect decorator registration + import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport] + + logger.info("Enterprise telemetry initialized") + + +def get_enterprise_exporter() -> EnterpriseExporter | None: + return _exporter diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index a5baa21018..63edbe93e7 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -78,16 +78,24 @@ def init_app(app: DifyApp): protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower() if dify_config.OTEL_EXPORTER_TYPE == "otlp": if protocol == "grpc": + # Auto-detect TLS: https:// uses secure, everything else is insecure + endpoint = dify_config.OTLP_BASE_ENDPOINT + insecure = not endpoint.startswith("https://") + + # Header field names must consist of lowercase letters, check RFC7540 + grpc_headers = ( + (("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else () + ) + exporter = GRPCSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - # Header field names must consist of lowercase letters, check RFC7540 - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) metric_exporter = GRPCMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) else: headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 120febecfb..651f8ed898 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,13 +5,12 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk + from graphon.model_runtime.errors.invoke import InvokeRateLimitError from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from graphon.model_runtime.errors.invoke import InvokeRateLimitError - def before_send(event, hint): if "exc_info" in hint: _, exc_value, _ = hint["exc_info"] diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index bdfa984874..db599c5d49 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from collections.abc 
import Sequence from datetime import datetime from typing import Any +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value -from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -60,7 +60,7 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN model.node_id = data.get("node_id") or "" model.node_type = data.get("node_type") or "" - model.status = data.get("status") or "running" # Default status if missing + model.status = WorkflowNodeExecutionStatus(data.get("status") or "running") model.title = data.get("title") or "" created_by_role_val = data.get("created_by_role") try: diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 5208f8f37e..3c83ab4f84 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,12 +20,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string -from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index ea4a2b3dd1..f71b2fa1df 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -4,14 +4,14 @@ import os import time from typing import Union +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 976b5db8e3..b725436681 100644 --- 
a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -13,6 +13,10 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, Union +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -21,10 +25,6 @@ from core.repositories.factory import OrderConfig, WorkflowNodeExecutionReposito from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier -from graphon.entities import WorkflowNodeExecution -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py index 164db7c275..c671e8b409 100644 --- a/api/extensions/otel/parser/__init__.py +++ b/api/extensions/otel/parser/__init__.py @@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set OpenTelemetry span attributes according to semantic conventions. """ -from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.parser.llm import LLMNodeOTelParser from extensions.otel.parser.retrieval import RetrievalNodeOTelParser from extensions.otel.parser.tool import ToolNodeOTelParser @@ -17,4 +17,5 @@ __all__ = [ "RetrievalNodeOTelParser", "ToolNodeOTelParser", "safe_json_dumps", + "should_include_content", ] diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index a2f552cac1..23d324f9ea 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -1,20 +1,36 @@ """ Base parser interface and utilities for OpenTelemetry node parsers. + +Content gating: ``should_include_content()`` controls whether content-bearing +span attributes (inputs, outputs, prompts, completions, documents) are written. +Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when +``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged. 
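+
+Effective behaviour of ``should_include_content()`` (defined below):
+
+- ``ENTERPRISE_ENABLED=False``: content is always written (CE default).
+- ``ENTERPRISE_ENABLED=True``, ``ENTERPRISE_INCLUDE_CONTENT=True``: written.
+- ``ENTERPRISE_ENABLED=True``, ``ENTERPRISE_INCLUDE_CONTENT=False``: omitted.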
""" import json from typing import Any, Protocol +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel +from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes -from graphon.enums import BuiltinNodeTypes -from graphon.file.models import File -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment + + +def should_include_content() -> bool: + """Return True if content should be written to spans. + + CE (ENTERPRISE_ENABLED=False): always True — no behaviour change. + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + return dify_config.ENTERPRISE_INCLUDE_CONTENT def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str: @@ -101,10 +117,11 @@ class DefaultNodeOTelParser: # Extract inputs and outputs from result_event if result_event and result_event.node_run_result: node_run_result = result_event.node_run_result - if node_run_result.inputs: - span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) - if node_run_result.outputs: - span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) + if should_include_content(): + if node_run_result.inputs: + span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) + if node_run_result.outputs: + span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) if error: span.record_exception(error) diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index ec3c78a12d..335c5cc29e 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,12 +6,12 @@ import logging from collections.abc import Mapping from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 56672d1fd4..6df5f62c15 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,13 +6,13 @@ import logging from collections.abc import Sequence from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index 75ddbba448..b9fdd9e1ca 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,14 +2,14 @@ Parser for tool nodes that captures 
tool-specific metadata. """ -from opentelemetry.trace import Span - -from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps -from extensions.otel.semconv.gen_ai import ToolAttributes from graphon.enums import WorkflowNodeExecutionMetadataKey from graphon.graph_events import GraphNodeEventBase from graphon.nodes.base.node import Node from graphon.nodes.tool.entities import ToolNodeData +from opentelemetry.trace import Span + +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.semconv.gen_ai import ToolAttributes class ToolNodeOTelParser: diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py index a20b9b358d..301ddd11aa 100644 --- a/api/extensions/otel/semconv/dify.py +++ b/api/extensions/otel/semconv/dify.py @@ -21,3 +21,15 @@ class DifySpanAttributes: INVOKE_FROM = "dify.invoke_from" """Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER.""" + + INVOKED_BY = "dify.invoked_by" + """Invoked by, e.g. end_user, account, user.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens (prompt tokens) used.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens (completion tokens) generated.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens used.""" diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index bc87510d43..7516d18c8e 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -7,13 +7,12 @@ import uuid from collections.abc import Mapping, Sequence from typing import Any +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference from extensions.ext_database import db -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers -from graphon.file.file_factory import standardize_file_type from models import ToolFile, UploadFile from .common import resolve_mapping_file_id diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py index 4b3d514238..5582b85c95 100644 --- a/api/factories/file_factory/message_files.py +++ b/api/factories/file_factory/message_files.py @@ -4,8 +4,9 @@ from __future__ import annotations from collections.abc import Sequence -from core.app.file_access import FileAccessControllerProtocol from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig + +from core.app.file_access import FileAccessControllerProtocol from models import MessageFile from .builders import build_from_mapping diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py index dba4c84407..db3a7f3015 100644 --- a/api/factories/file_factory/storage_keys.py +++ b/api/factories/file_factory/storage_keys.py @@ -5,12 +5,12 @@ from __future__ import annotations import uuid from collections.abc import Mapping, Sequence +from graphon.file import File, FileTransferMethod from sqlalchemy import select from sqlalchemy.orm import Session from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference, parse_file_reference -from graphon.file import File, FileTransferMethod from models import ToolFile, UploadFile diff --git 
a/api/factories/variable_factory.py b/api/factories/variable_factory.py index fd7acb14d3..57205b5739 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -8,11 +8,6 @@ shared conversion functions for legacy callers and tests. from collections.abc import Mapping, Sequence from typing import Any, cast -from configs import dify_config -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) from graphon.variables.exc import VariableError from graphon.variables.factory import ( TypeMismatchError, @@ -36,6 +31,12 @@ from graphon.variables.variables import ( VariableBase, ) +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) + __all__ = [ "TypeMismatchError", "UnsupportedSegmentTypeError", diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 801949747e..30d02aeedc 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,9 +3,8 @@ from __future__ import annotations from datetime import datetime from typing import Any, TypeAlias -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - from graphon.file import File +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator JSONValue: TypeAlias = Any diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 4e201e66e6..b8daa5af30 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,9 +3,8 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields -from pydantic import BaseModel, ConfigDict, computed_field, field_validator - from graphon.file import helpers as file_helpers +from pydantic import BaseModel, ConfigDict, computed_field, field_validator simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 86c4f285cd..d982c31aee 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -4,11 +4,11 @@ from datetime import datetime from typing import TypeAlias from uuid import uuid4 +from graphon.file import File from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from fields.conversation_fields import AgentThought, JSONValue, MessageFile -from graphon.file import File JSONValueType: TypeAlias = JSONValue diff --git a/api/fields/raws.py b/api/fields/raws.py index ee6f53b360..4c65cdab7a 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,5 +1,4 @@ from flask_restx import fields - from graphon.file import File diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f9b5e98936..b0b6cc0b48 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter from fields.member_fields import simple_account_fields -from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type diff --git a/api/graphon/README.md b/api/graphon/README.md deleted file mode 100644 index 725f122cd8..0000000000 --- a/api/graphon/README.md +++ /dev/null @@ -1,135 +0,0 @@ -# Workflow - -## Project 
Overview - -This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control. - -## Architecture - -### Core Components - -The graph engine follows a layered architecture with strict dependency rules: - -1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution - - - **Manager** - External control interface for stop/pause/resume commands - - **Worker** - Node execution runtime - - **Command Processing** - Handles control commands (abort, pause, resume) - - **Event Management** - Event propagation and layer notifications - - **Graph Traversal** - Edge processing and skip propagation - - **Response Coordinator** - Path tracking and session management - - **Layers** - Pluggable middleware (debug logging, execution limits) - - **Command Channels** - Communication channels (InMemory, Redis) - -1. **Graph** (`graph/`) - Graph structure and runtime state - - - **Graph Template** - Workflow definition - - **Edge** - Node connections with conditions - - **Runtime State Protocol** - State management interface - -1. **Nodes** (`nodes/`) - Node implementations - - - **Base** - Abstract node classes and variable parsing - - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc. - -1. **Events** (`node_events/`) - Event system - - - **Base** - Event protocols - - **Node Events** - Node lifecycle events - -1. **Entities** (`entities/`) - Domain models - - - **Variable Pool** - Variable storage - - **Graph Init Params** - Initialization configuration - -## Key Design Patterns - -### Command Channel Pattern - -External workflow control via Redis or in-memory channels: - -```python -# Send stop command to running workflow -channel = RedisChannel(redis_client, f"workflow:{task_id}:commands") -channel.send_command(AbortCommand(reason="User requested")) -``` - -### Layer System - -Extensible middleware for cross-cutting concerns: - -```python -engine = GraphEngine(graph) -engine.layer(DebugLoggingLayer(level="INFO")) -engine.layer(ExecutionLimitsLayer(max_nodes=100)) -``` - -`engine.layer()` binds the read-only runtime state before execution, so layer hooks -can assume `graph_runtime_state` is available. - -### Event-Driven Architecture - -All node executions emit events for monitoring and integration: - -- `NodeRunStartedEvent` - Node execution begins -- `NodeRunSucceededEvent` - Node completes successfully -- `NodeRunFailedEvent` - Node encounters error -- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle - -### Variable Pool - -Centralized variable storage with namespace isolation: - -```python -# Variables scoped by node_id -pool.add(["node1", "output"], value) -result = pool.get(["node1", "output"]) -``` - -## Import Architecture Rules - -The codebase enforces strict layering via import-linter: - -1. **Workflow Layers** (top to bottom): - - - graph_engine → graph_events → graph → nodes → node_events → entities - -1. **Graph Engine Internal Layers**: - - - orchestration → command_processing → event_management → graph_traversal → domain - -1. **Domain Isolation**: - - - Domain models cannot import from infrastructure layers - -1. **Command Channel Independence**: - - - InMemory and Redis channels must remain independent - -## Common Tasks - -### Adding a New Node Type - -1. Create node class in `nodes//` -1. Inherit from `BaseNode` or appropriate base class -1. 
Implement `_run()` method -1. Ensure the node module is importable under `nodes//` -1. Add tests in `tests/unit_tests/graphon/nodes/` - -### Implementing a Custom Layer - -1. Create class inheriting from `Layer` base -1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` -1. Add to engine via `engine.layer()` - -### Debugging Workflow Execution - -Enable debug logging layer: - -```python -debug_layer = DebugLoggingLayer( - level="DEBUG", - include_inputs=True, - include_outputs=True -) -``` diff --git a/api/graphon/entities/__init__.py b/api/graphon/entities/__init__.py deleted file mode 100644 index ef7789c49c..0000000000 --- a/api/graphon/entities/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .graph_init_params import GraphInitParams -from .workflow_execution import WorkflowExecution -from .workflow_node_execution import WorkflowNodeExecution -from .workflow_start_reason import WorkflowStartReason - -__all__ = [ - "GraphInitParams", - "WorkflowExecution", - "WorkflowNodeExecution", - "WorkflowStartReason", -] diff --git a/api/graphon/entities/base_node_data.py b/api/graphon/entities/base_node_data.py deleted file mode 100644 index e8267043a9..0000000000 --- a/api/graphon/entities/base_node_data.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import annotations - -import json -from abc import ABC -from builtins import type as type_ -from enum import StrEnum -from typing import Any, Union - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from graphon.entities.exc import DefaultValueTypeError -from graphon.enums import ErrorStrategy, NodeType - -# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. -_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds - retry_enabled: bool = False # whether retry is enabled - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 - - -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = { - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: { - "type": list, - "element_type": 
_NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where - # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. - # `type` therefore accepts downstream string node kinds; unknown node implementations - # are rejected later when the node factory resolves the node registry. - # At that boundary, node-specific fields are still "extra" relative to this shared DTO, - # and persisted templates/workflows also carry undeclared compatibility keys such as - # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive - # here until graph parsing becomes discriminated by node type or those legacy payloads - # are normalized. - model_config = ConfigDict(extra="allow") - - type: NodeType - title: str = "" - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = Field(default_factory=RetryConfig) - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - def __getitem__(self, key: str) -> Any: - """ - Dict-style access without calling model_dump() on every lookup. - Prefer using model fields and Pydantic's extra storage. - """ - # First, check declared model fields - if key in self.__class__.model_fields: - return getattr(self, key) - - # Then, check undeclared compatibility fields stored in Pydantic's extra dict. - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras[key] - - raise KeyError(key) - - def get(self, key: str, default: Any = None) -> Any: - """ - Dict-style .get() without calling model_dump() on every lookup. 
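-
-        Lookup order mirrors ``__getitem__``: declared model fields first, then
-        Pydantic's extra storage (compatibility keys such as `selected` or
-        `paramSchemas`), falling back to ``default`` when the key is absent.
-        For example, ``data.get("desc")`` reads a declared field, while
-        ``data.get("selected", False)`` reads a compatibility extra.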
- """ - if key in self.__class__.model_fields: - return getattr(self, key) - - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras.get(key, default) - - return default diff --git a/api/graphon/entities/exc.py b/api/graphon/entities/exc.py deleted file mode 100644 index aeecf40640..0000000000 --- a/api/graphon/entities/exc.py +++ /dev/null @@ -1,10 +0,0 @@ -class BaseNodeError(ValueError): - """Base class for node errors.""" - - pass - - -class DefaultValueTypeError(BaseNodeError): - """Raised when the default value type is invalid.""" - - pass diff --git a/api/graphon/entities/graph_config.py b/api/graphon/entities/graph_config.py deleted file mode 100644 index 392241c631..0000000000 --- a/api/graphon/entities/graph_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import sys - -from pydantic import TypeAdapter, with_config - -from graphon.entities.base_node_data import BaseNodeData - -if sys.version_info >= (3, 12): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -@with_config(extra="allow") -class NodeConfigDict(TypedDict): - id: str - # This is the permissive raw graph boundary. Node factories re-validate `data` - # with the concrete `NodeData` subtype after resolving the node implementation. - data: BaseNodeData - - -NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/graphon/entities/graph_init_params.py b/api/graphon/entities/graph_init_params.py deleted file mode 100644 index f785d58a52..0000000000 --- a/api/graphon/entities/graph_init_params.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -DIFY_RUN_CONTEXT_KEY = "_dify" - - -class GraphInitParams(BaseModel): - """GraphInitParams encapsulates the configurations and contextual information - that remain constant throughout a single execution of the graph engine. - - A single execution is defined as follows: as long as the execution has not reached - its conclusion, it is considered one execution. For instance, if a workflow is suspended - and later resumed, it is still regarded as a single execution, not two. - - For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. 
- """ - - # init params - workflow_id: str = Field(..., description="workflow id") - graph_config: Mapping[str, Any] = Field(..., description="graph config") - run_context: Mapping[str, Any] = Field(..., description="runtime context") - call_depth: int = Field(..., description="call depth") diff --git a/api/graphon/entities/pause_reason.py b/api/graphon/entities/pause_reason.py deleted file mode 100644 index ba2973fd45..0000000000 --- a/api/graphon/entities/pause_reason.py +++ /dev/null @@ -1,42 +0,0 @@ -from collections.abc import Mapping -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, TypeAlias - -from pydantic import BaseModel, Field - -from graphon.nodes.human_input.entities import FormInput, UserAction - - -class PauseReasonType(StrEnum): - HUMAN_INPUT_REQUIRED = auto() - SCHEDULED_PAUSE = auto() - - -class HumanInputRequired(BaseModel): - TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED - form_id: str - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - actions: list[UserAction] = Field(default_factory=list) - node_id: str - node_title: str - - # The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from - # `output_variable_name` to their resolved values. - # - # For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its - # selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable - # `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The - # `resolved_default_values` is `{"name": "John"}`. - # - # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - - -class SchedulingPause(BaseModel): - TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE - - message: str - - -PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")] diff --git a/api/graphon/entities/workflow_execution.py b/api/graphon/entities/workflow_execution.py deleted file mode 100644 index b8de7eed1a..0000000000 --- a/api/graphon/entities/workflow_execution.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Domain entities for workflow execution. - -Models describe graph runtime state and avoid infrastructure-specific details. -""" - -from __future__ import annotations - -from collections.abc import Mapping -from datetime import UTC, datetime -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.enums import WorkflowExecutionStatus, WorkflowType - - -class WorkflowExecution(BaseModel): - """ - Domain model for a workflow execution within the graph runtime. - """ - - id_: str = Field(...) - workflow_id: str = Field(...) - workflow_version: str = Field(...) - workflow_type: WorkflowType = Field(...) - graph: Mapping[str, Any] = Field(...) - - inputs: Mapping[str, Any] = Field(...) - outputs: Mapping[str, Any] | None = None - - status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING - error_message: str = Field(default="") - total_tokens: int = Field(default=0) - total_steps: int = Field(default=0) - exceptions_count: int = Field(default=0) - - started_at: datetime = Field(...) - finished_at: datetime | None = None - - @property - def elapsed_time(self) -> float: - """ - Calculate elapsed time in seconds. 
- If workflow is not finished, use current time. - """ - end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) - return (end_time - self.started_at).total_seconds() - - @classmethod - def new( - cls, - *, - id_: str, - workflow_id: str, - workflow_type: WorkflowType, - workflow_version: str, - graph: Mapping[str, Any], - inputs: Mapping[str, Any], - started_at: datetime, - ) -> WorkflowExecution: - return WorkflowExecution( - id_=id_, - workflow_id=workflow_id, - workflow_type=workflow_type, - workflow_version=workflow_version, - graph=graph, - inputs=inputs, - status=WorkflowExecutionStatus.RUNNING, - started_at=started_at, - ) diff --git a/api/graphon/entities/workflow_node_execution.py b/api/graphon/entities/workflow_node_execution.py deleted file mode 100644 index 5458572e7e..0000000000 --- a/api/graphon/entities/workflow_node_execution.py +++ /dev/null @@ -1,141 +0,0 @@ -""" -Domain entities for workflow node execution. - -These models capture node-level execution state for the graph runtime without -describing storage or application-layer concerns. -""" - -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, Field, PrivateAttr - -from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus - - -class WorkflowNodeExecution(BaseModel): - """ - Domain model for workflow node execution. - - This model represents the graph-level record of a node execution and - contains only execution state relevant to the runtime. - """ - - # --------- Core identification fields --------- - - # Unique identifier for this execution record, used when persisting to storage. - # Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382'). - id: str - - # Optional secondary ID for cross-referencing purposes. - # - # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`. - # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed. - # In most scenarios, `id` should be used as the primary identifier. - node_execution_id: str | None = None - workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: str | None = None # ID of the workflow execution (null for single-step debugging) - # --------- Core identification fields ends --------- - - # Execution positioning and flow - index: int # Sequence number for ordering in trace visualization - predecessor_node_id: str | None = None # ID of the node that executed before this one - node_id: str # ID of the node being executed - node_type: NodeType # Type of node (e.g., start, llm, downstream response node) - title: str # Display title of the node - - # Execution data - # The `inputs` and `outputs` fields hold the full content - inputs: Mapping[str, Any] | None = None # Input variables used by this node - process_data: Mapping[str, Any] | None = None # Intermediate processing data - outputs: Mapping[str, Any] | None = None # Output variables produced by this node - - # Execution state - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status - error: str | None = None # Error message if execution failed - elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds - - # Additional metadata - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.) 
- - # Timing information - created_at: datetime # When execution started - finished_at: datetime | None = None # When execution completed - - _truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None) - - def get_truncated_inputs(self) -> Mapping[str, Any] | None: - return self._truncated_inputs - - def get_truncated_outputs(self) -> Mapping[str, Any] | None: - return self._truncated_outputs - - def get_truncated_process_data(self) -> Mapping[str, Any] | None: - return self._truncated_process_data - - def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None): - self._truncated_inputs = truncated_inputs - - def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None): - self._truncated_outputs = truncated_outputs - - def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None): - self._truncated_process_data = truncated_process_data - - def get_response_inputs(self) -> Mapping[str, Any] | None: - inputs = self.get_truncated_inputs() - if inputs: - return inputs - return self.inputs - - @property - def inputs_truncated(self): - return self._truncated_inputs is not None - - @property - def outputs_truncated(self): - return self._truncated_outputs is not None - - @property - def process_data_truncated(self): - return self._truncated_process_data is not None - - def get_response_outputs(self) -> Mapping[str, Any] | None: - outputs = self.get_truncated_outputs() - if outputs is not None: - return outputs - return self.outputs - - def get_response_process_data(self) -> Mapping[str, Any] | None: - process_data = self.get_truncated_process_data() - if process_data is not None: - return process_data - return self.process_data - - def update_from_mapping( - self, - inputs: Mapping[str, Any] | None = None, - process_data: Mapping[str, Any] | None = None, - outputs: Mapping[str, Any] | None = None, - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, - ): - """ - Update the model from mappings. - - Args: - inputs: The inputs to update - process_data: The process data to update - outputs: The outputs to update - metadata: The metadata to update - """ - if inputs is not None: - self.inputs = dict(inputs) - if process_data is not None: - self.process_data = dict(process_data) - if outputs is not None: - self.outputs = dict(outputs) - if metadata is not None: - self.metadata = dict(metadata) diff --git a/api/graphon/entities/workflow_start_reason.py b/api/graphon/entities/workflow_start_reason.py deleted file mode 100644 index df0f75383b..0000000000 --- a/api/graphon/entities/workflow_start_reason.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import StrEnum - - -class WorkflowStartReason(StrEnum): - """Reason for workflow start events across graph/queue/SSE layers.""" - - INITIAL = "initial" # First start of a workflow run. - RESUMPTION = "resumption" # Start triggered after resuming a paused run. diff --git a/api/graphon/enums.py b/api/graphon/enums.py deleted file mode 100644 index bbc973abe5..0000000000 --- a/api/graphon/enums.py +++ /dev/null @@ -1,262 +0,0 @@ -from enum import StrEnum -from typing import ClassVar, TypeAlias - - -class NodeState(StrEnum): - """State of a node or edge during workflow execution.""" - - UNKNOWN = "unknown" - TAKEN = "taken" - SKIPPED = "skipped" - - -NodeType: TypeAlias = str - - -class BuiltinNodeTypes: - """Built-in node type string constants. 
-
-    `node_type` values are plain strings throughout the graph runtime. This namespace
-    only exposes the built-in values shipped by `graphon`; downstream packages can
-    use additional strings without extending this class.
-    """
-
-    START: ClassVar[NodeType] = "start"
-    END: ClassVar[NodeType] = "end"
-    ANSWER: ClassVar[NodeType] = "answer"
-    LLM: ClassVar[NodeType] = "llm"
-    KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval"
-    IF_ELSE: ClassVar[NodeType] = "if-else"
-    CODE: ClassVar[NodeType] = "code"
-    TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform"
-    QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier"
-    HTTP_REQUEST: ClassVar[NodeType] = "http-request"
-    TOOL: ClassVar[NodeType] = "tool"
-    DATASOURCE: ClassVar[NodeType] = "datasource"
-    VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator"
-    LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner"
-    LOOP: ClassVar[NodeType] = "loop"
-    LOOP_START: ClassVar[NodeType] = "loop-start"
-    LOOP_END: ClassVar[NodeType] = "loop-end"
-    ITERATION: ClassVar[NodeType] = "iteration"
-    ITERATION_START: ClassVar[NodeType] = "iteration-start"
-    PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor"
-    VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner"
-    DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor"
-    LIST_OPERATOR: ClassVar[NodeType] = "list-operator"
-    AGENT: ClassVar[NodeType] = "agent"
-    HUMAN_INPUT: ClassVar[NodeType] = "human-input"
-
-
-BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = (
-    BuiltinNodeTypes.START,
-    BuiltinNodeTypes.END,
-    BuiltinNodeTypes.ANSWER,
-    BuiltinNodeTypes.LLM,
-    BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
-    BuiltinNodeTypes.IF_ELSE,
-    BuiltinNodeTypes.CODE,
-    BuiltinNodeTypes.TEMPLATE_TRANSFORM,
-    BuiltinNodeTypes.QUESTION_CLASSIFIER,
-    BuiltinNodeTypes.HTTP_REQUEST,
-    BuiltinNodeTypes.TOOL,
-    BuiltinNodeTypes.DATASOURCE,
-    BuiltinNodeTypes.VARIABLE_AGGREGATOR,
-    BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR,
-    BuiltinNodeTypes.LOOP,
-    BuiltinNodeTypes.LOOP_START,
-    BuiltinNodeTypes.LOOP_END,
-    BuiltinNodeTypes.ITERATION,
-    BuiltinNodeTypes.ITERATION_START,
-    BuiltinNodeTypes.PARAMETER_EXTRACTOR,
-    BuiltinNodeTypes.VARIABLE_ASSIGNER,
-    BuiltinNodeTypes.DOCUMENT_EXTRACTOR,
-    BuiltinNodeTypes.LIST_OPERATOR,
-    BuiltinNodeTypes.AGENT,
-    BuiltinNodeTypes.HUMAN_INPUT,
-)
-
-
-class NodeExecutionType(StrEnum):
-    """Node execution type classification."""
-
-    EXECUTABLE = "executable"  # Regular nodes that execute and produce outputs
-    RESPONSE = "response"  # Response nodes that stream outputs (Answer, End)
-    BRANCH = "branch"  # Nodes that can choose different branches (if-else, question-classifier)
-    CONTAINER = "container"  # Container nodes that manage subgraphs (iteration, loop, graph)
-    ROOT = "root"  # Nodes that can serve as execution entry points
-
-
-class ErrorStrategy(StrEnum):
-    FAIL_BRANCH = "fail-branch"
-    DEFAULT_VALUE = "default-value"
-
-
-class FailBranchSourceHandle(StrEnum):
-    FAILED = "fail-branch"
-    SUCCESS = "success-branch"
-
-
-class WorkflowType(StrEnum):
-    """
-    Workflow Type Enum for domain layer
-    """
-
-    WORKFLOW = "workflow"
-    CHAT = "chat"
-    RAG_PIPELINE = "rag-pipeline"
-
-
-class WorkflowExecutionStatus(StrEnum):
-    # State diagram for the workflow status (Mermaid):
-    #
-    # ---
-    # title: State diagram for Workflow run state
-    # ---
-    # stateDiagram-v2
-    #     scheduled: Scheduled
-    #     running: Running
-    #     succeeded: Succeeded
-    #     failed: Failed
-    #     partial_succeeded: Partial Succeeded
-    #     paused: Paused
-    #     stopped: Stopped
-    #
-    #     [*] --> scheduled:
-    #     scheduled --> running: Start Execution
-    #     running --> paused: Human input required
-    #     paused --> running: human input added
-    #     paused --> stopped: User stops execution
-    #     running --> succeeded: Execution finishes without any error
-    #     running --> failed: Execution finishes with errors
-    #     running --> stopped: User stops execution
-    #     running --> partial_succeeded: some errors occurred and were handled during execution
-    #
-    #     scheduled --> stopped: User stops execution
-    #
-    #     succeeded --> [*]
-    #     failed --> [*]
-    #     partial_succeeded --> [*]
-    #     stopped --> [*]
-
-    # `SCHEDULED` means that the workflow is scheduled to run, but has not
-    # started running yet (possibly due to worker saturation).
-    #
-    # This enum value is currently unused.
-    SCHEDULED = "scheduled"
-
-    # `RUNNING` means the workflow is executing.
-    RUNNING = "running"
-
-    # `SUCCEEDED` means the execution of the workflow succeeded without any error.
-    SUCCEEDED = "succeeded"
-
-    # `FAILED` means the execution of the workflow failed with errors.
-    FAILED = "failed"
-
-    # `STOPPED` means the execution of workflow was stopped, either manually
-    # by the user, or automatically by the Dify application (e.g. the moderation
-    # mechanism).
-    STOPPED = "stopped"
-
-    # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
-    # execution, but they were successfully handled (e.g., by using an error
-    # strategy such as "fail branch" or "default value").
-    PARTIAL_SUCCEEDED = "partial-succeeded"
-
-    # `PAUSED` indicates that the workflow execution is temporarily paused
-    # (e.g., awaiting human input) and is expected to resume later.
-    PAUSED = "paused"
-
-    def is_ended(self) -> bool:
-        return self in _END_STATE
-
-    @classmethod
-    def ended_values(cls) -> list[str]:
-        return [status.value for status in _END_STATE]
-
-
-_END_STATE = frozenset(
-    [
-        WorkflowExecutionStatus.SUCCEEDED,
-        WorkflowExecutionStatus.FAILED,
-        WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
-        WorkflowExecutionStatus.STOPPED,
-    ]
-)
-
-
-class WorkflowNodeExecutionMetadataKey(StrEnum):
-    """
-    Node Run Metadata Key.
-
-    Values in this enum are persisted as execution metadata and must stay in sync
-    with every node that writes `NodeRunResult.metadata`.
-    """
-
-    TOTAL_TOKENS = "total_tokens"
-    TOTAL_PRICE = "total_price"
-    CURRENCY = "currency"
-    TOOL_INFO = "tool_info"
-    AGENT_LOG = "agent_log"
-    ITERATION_ID = "iteration_id"
-    ITERATION_INDEX = "iteration_index"
-    LOOP_ID = "loop_id"
-    LOOP_INDEX = "loop_index"
-    PARALLEL_ID = "parallel_id"
-    PARALLEL_START_NODE_ID = "parallel_start_node_id"
-    PARENT_PARALLEL_ID = "parent_parallel_id"
-    PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
-    PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
-    ITERATION_DURATION_MAP = "iteration_duration_map"  # single iteration duration if iteration node runs
-    LOOP_DURATION_MAP = "loop_duration_map"  # single loop duration if loop node runs
-    ERROR_STRATEGY = "error_strategy"  # nodes in continue-on-error mode return this field
-    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
-    DATASOURCE_INFO = "datasource_info"
-    TRIGGER_INFO = "trigger_info"
-    COMPLETED_REASON = "completed_reason"  # completed reason for loop node
-
-
-class WorkflowNodeExecutionStatus(StrEnum):
-    PENDING = "pending"  # Node is scheduled but not yet executing
-    RUNNING = "running"
-    SUCCEEDED = "succeeded"
-    FAILED = "failed"
-    EXCEPTION = "exception"
-    STOPPED = "stopped"
-    PAUSED = "paused"
-
-    # Legacy statuses - kept for backward compatibility
-    RETRY = "retry"  # Legacy: replaced by retry mechanism in error handling
diff --git a/api/graphon/errors.py b/api/graphon/errors.py
deleted file mode 100644
index 7eb007524d..0000000000
--- a/api/graphon/errors.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from graphon.nodes.base.node import Node
-
-
-class WorkflowNodeRunFailedError(Exception):
-    def __init__(self, node: Node, err_msg: str):
-        self._node = node
-        self._error = err_msg
-        super().__init__(f"Node {node.title} run failed: {err_msg}")
-
-    @property
-    def node(self) -> Node:
-        return self._node
-
-    @property
-    def error(self) -> str:
-        return self._error
diff --git a/api/graphon/file/__init__.py b/api/graphon/file/__init__.py
deleted file mode 100644
index 4908ae9795..0000000000
--- a/api/graphon/file/__init__.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from .constants import FILE_MODEL_IDENTITY
-from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
-from .file_factory import get_file_type_by_mime_type, standardize_file_type
-from .models import (
-    File,
-    FileUploadConfig,
-    ImageConfig,
-)
-
-__all__ = [
-    "FILE_MODEL_IDENTITY",
-    "ArrayFileAttribute",
-    "File",
-    "FileAttribute",
-    "FileBelongsTo",
-    "FileTransferMethod",
-    "FileType",
-    "FileUploadConfig",
-    "ImageConfig",
-    "get_file_type_by_mime_type",
-    "standardize_file_type",
-]
diff --git a/api/graphon/file/constants.py b/api/graphon/file/constants.py
deleted file mode 100644
index 56b95b5f0d..0000000000
--- a/api/graphon/file/constants.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from collections.abc import Iterable
-from typing import Any
-
-# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
-# comparing `dify_model_identity` with constants throughout the codebase, extract
-# this logic into a dedicated function. This would encapsulate the implementation
-# details of how different variable types are identified.
-FILE_MODEL_IDENTITY = "__dify__file__" -DEFAULT_MIME_TYPE = "application/octet-stream" -DEFAULT_EXTENSION = ".bin" - - -def _with_case_variants(extensions: Iterable[str]) -> frozenset[str]: - normalized = {extension.lower() for extension in extensions} - return frozenset(normalized | {extension.upper() for extension in normalized}) - - -IMAGE_EXTENSIONS = _with_case_variants({"jpg", "jpeg", "png", "webp", "gif", "svg"}) -VIDEO_EXTENSIONS = _with_case_variants({"mp4", "mov", "mpeg", "webm"}) -AUDIO_EXTENSIONS = _with_case_variants({"mp3", "m4a", "wav", "amr", "mpga"}) -DOCUMENT_EXTENSIONS = _with_case_variants( - { - "txt", - "markdown", - "md", - "mdx", - "pdf", - "html", - "htm", - "xlsx", - "xls", - "vtt", - "properties", - "doc", - "docx", - "csv", - "eml", - "msg", - "ppt", - "pptx", - "xml", - "epub", - } -) - - -def maybe_file_object(o: Any) -> bool: - return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/graphon/file/enums.py b/api/graphon/file/enums.py deleted file mode 100644 index 170eb4fc23..0000000000 --- a/api/graphon/file/enums.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum - - -class FileType(StrEnum): - IMAGE = "image" - DOCUMENT = "document" - AUDIO = "audio" - VIDEO = "video" - CUSTOM = "custom" - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(StrEnum): - REMOTE_URL = "remote_url" - LOCAL_FILE = "local_file" - TOOL_FILE = "tool_file" - DATASOURCE_FILE = "datasource_file" - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileBelongsTo(StrEnum): - USER = "user" - ASSISTANT = "assistant" - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileAttribute(StrEnum): - TYPE = "type" - SIZE = "size" - NAME = "name" - MIME_TYPE = "mime_type" - TRANSFER_METHOD = "transfer_method" - URL = "url" - EXTENSION = "extension" - RELATED_ID = "related_id" - - -class ArrayFileAttribute(StrEnum): - LENGTH = "length" diff --git a/api/graphon/file/file_factory.py b/api/graphon/file/file_factory.py deleted file mode 100644 index 3d20b9377d..0000000000 --- a/api/graphon/file/file_factory.py +++ /dev/null @@ -1,39 +0,0 @@ -from .constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from .enums import FileType - - -def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the actual file type from extension and mime type. 
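-
-    The extension is checked first; the mime type is only consulted as a
-    fallback. For example (illustrative):
-
-        standardize_file_type(extension="png")              # FileType.IMAGE
-        standardize_file_type(mime_type="application/pdf")  # FileType.DOCUMENT
-        standardize_file_type()                             # FileType.CUSTOM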
- """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = get_file_type_by_mime_type(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - normalized_extension = extension.lstrip(".") - if normalized_extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - if normalized_extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - if normalized_extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - if normalized_extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - if "image" in mime_type: - return FileType.IMAGE - if "video" in mime_type: - return FileType.VIDEO - if "audio" in mime_type: - return FileType.AUDIO - if "text" in mime_type or "pdf" in mime_type: - return FileType.DOCUMENT - return FileType.CUSTOM diff --git a/api/graphon/file/file_manager.py b/api/graphon/file/file_manager.py deleted file mode 100644 index d7e4d472e7..0000000000 --- a/api/graphon/file/file_manager.py +++ /dev/null @@ -1,129 +0,0 @@ -from __future__ import annotations - -import base64 -from collections.abc import Mapping - -from graphon.model_runtime.entities import ( - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - TextPromptMessageContent, - VideoPromptMessageContent, -) -from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes - -from .enums import FileAttribute -from .models import File, FileTransferMethod, FileType -from .runtime import get_workflow_file_runtime - - -def get_attr(*, file: File, attr: FileAttribute): - match attr: - case FileAttribute.TYPE: - return file.type.value - case FileAttribute.SIZE: - return file.size - case FileAttribute.NAME: - return file.filename - case FileAttribute.MIME_TYPE: - return file.mime_type - case FileAttribute.TRANSFER_METHOD: - return file.transfer_method.value - case FileAttribute.URL: - return _to_url(file) - case FileAttribute.EXTENSION: - return file.extension - case FileAttribute.RELATED_ID: - return file.related_id - - -def to_prompt_message_content( - f: File, - /, - *, - image_detail_config: ImagePromptMessageContent.DETAIL | None = None, -) -> PromptMessageContentUnionTypes: - """Convert a file to prompt message content.""" - if f.extension is None: - raise ValueError("Missing file extension") - if f.mime_type is None: - raise ValueError("Missing file mime_type") - - prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { - FileType.IMAGE: ImagePromptMessageContent, - FileType.AUDIO: AudioPromptMessageContent, - FileType.VIDEO: VideoPromptMessageContent, - FileType.DOCUMENT: DocumentPromptMessageContent, - } - - if f.type not in prompt_class_map: - return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]") - - send_format = get_workflow_file_runtime().multimodal_send_format - params = { - "base64_data": _get_encoded_string(f) if send_format == "base64" else "", - "url": _to_url(f) if send_format == "url" else "", - "format": f.extension.removeprefix("."), - "mime_type": f.mime_type, - "filename": f.filename or "", - } - if f.type == FileType.IMAGE: - params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - return prompt_class_map[f.type].model_validate(params) - - -def download(f: File, /) -> bytes: - if f.transfer_method in ( - 
FileTransferMethod.TOOL_FILE, - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.DATASOURCE_FILE, - ): - return _download_file_content(f) - elif f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - return response.content - raise ValueError(f"unsupported transfer method: {f.transfer_method}") - - -def _download_file_content(file: File, /) -> bytes: - """Download and return a file from storage as bytes.""" - return get_workflow_file_runtime().load_file_bytes(file=file) - - -def _get_encoded_string(f: File, /) -> str: - match f.transfer_method: - case FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - data = response.content - case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f) - case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f) - case FileTransferMethod.DATASOURCE_FILE: - data = _download_file_content(f) - - return base64.b64encode(data).decode("utf-8") - - -def _to_url(f: File, /): - url = f.generate_url() - if url is None: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") - return url - - -class FileManager: - """Adapter exposing file manager helpers behind FileManagerProtocol.""" - - def download(self, f: File, /) -> bytes: - return download(f) - - -file_manager = FileManager() diff --git a/api/graphon/file/helpers.py b/api/graphon/file/helpers.py deleted file mode 100644 index dade761227..0000000000 --- a/api/graphon/file/helpers.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .runtime import get_workflow_file_runtime - -if TYPE_CHECKING: - from .models import File - - -def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None: - return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external) - - -def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: - return get_workflow_file_runtime().resolve_upload_file_url( - upload_file_id=upload_file_id, - as_attachment=as_attachment, - for_external=for_external, - ) - - -def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: - return get_workflow_file_runtime().resolve_tool_file_url( - tool_file_id=tool_file_id, - extension=extension, - for_external=for_external, - ) - - -def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - return get_workflow_file_runtime().verify_preview_signature( - preview_kind="image", - file_id=upload_file_id, - timestamp=timestamp, - nonce=nonce, - sign=sign, - ) - - -def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - return get_workflow_file_runtime().verify_preview_signature( - preview_kind="file", - file_id=upload_file_id, - timestamp=timestamp, - nonce=nonce, - sign=sign, - ) diff --git a/api/graphon/file/models.py b/api/graphon/file/models.py deleted file mode 100644 index ccd7584371..0000000000 --- a/api/graphon/file/models.py +++ /dev/null @@ -1,215 +0,0 @@ -from __future__ import annotations - -import base64 -import json -from collections.abc import Mapping, Sequence -from typing import Any - 
-from pydantic import BaseModel, Field, model_validator - -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent - -from . import helpers -from .constants import FILE_MODEL_IDENTITY -from .enums import FileTransferMethod, FileType - -_FILE_REFERENCE_PREFIX = "dify-file-ref:" - - -def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str: - """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``.""" - return helpers.get_signed_tool_file_url( - tool_file_id=tool_file_id, - extension=extension, - for_external=for_external, - ) - - -class ImageConfig(BaseModel): - """ - NOTE: This part of validation is deprecated, but still used in app features "Image Upload". - """ - - number_limits: int = 0 - transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - detail: ImagePromptMessageContent.DETAIL | None = None - - -class FileUploadConfig(BaseModel): - """ - File Upload Entity. - """ - - image_config: ImageConfig | None = None - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_file_extensions: Sequence[str] = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = 0 - - -def _parse_reference(reference: str | None) -> tuple[str | None, str | None]: - """Best-effort parser for record references and historical storage-key payloads.""" - if not reference: - return None, None - - if not reference.startswith(_FILE_REFERENCE_PREFIX): - return reference, None - - encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) - try: - payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) - except (ValueError, json.JSONDecodeError): - return reference, None - - record_id = payload.get("record_id") - if not isinstance(record_id, str) or not record_id: - return reference, None - - storage_key = payload.get("storage_key") - if not isinstance(storage_key, str): - storage_key = None - - return record_id, storage_key - - -class File(BaseModel): - """Graph-owned file reference. - - The graph layer deliberately keeps only the metadata required to route, - serialize, and render files. Application ownership concerns such as - tenant/user/conversation identity stay in the workflow/storage layer. - """ - - # NOTE: dify_model_identity is a special identifier used to distinguish between - # new and old data formats during serialization and deserialization. - dify_model_identity: str = FILE_MODEL_IDENTITY - - id: str | None = None # message file id - type: FileType - transfer_method: FileTransferMethod - # If `transfer_method` is `FileTransferMethod.remote_url`, the - # `remote_url` attribute must not be `None`. - remote_url: str | None = None # remote url - # Opaque workflow-layer reference for files resolved outside ``graphon``. - # New payloads only carry the backing record id; historical payloads may - # still include storage_key and must remain readable. 
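`_parse_reference` accepts both bare record ids and prefixed payloads. A minimal sketch of the matching encoder (`encode_reference` is hypothetical, not part of this module, but mirrors the payload shape the parser expects):

```python
import base64
import json

_FILE_REFERENCE_PREFIX = "dify-file-ref:"

def encode_reference(record_id: str, storage_key: str | None = None) -> str:
    # Hypothetical inverse of _parse_reference: pack the record id (plus the
    # legacy storage_key, when present) as urlsafe base64 JSON.
    payload: dict[str, str] = {"record_id": record_id}
    if storage_key is not None:
        payload["storage_key"] = storage_key
    encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
    return _FILE_REFERENCE_PREFIX + encoded

# Round-trip: _parse_reference(encode_reference("rec-1", "key-1")) == ("rec-1", "key-1")
```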
- reference: str | None = None - filename: str | None = None - extension: str | None = Field(default=None, description="File extension, should contain dot") - mime_type: str | None = None - size: int = -1 - _storage_key: str - - def __init__( - self, - *, - id: str | None = None, - tenant_id: str | None = None, - type: FileType, - transfer_method: FileTransferMethod, - remote_url: str | None = None, - reference: str | None = None, - related_id: str | None = None, - filename: str | None = None, - extension: str | None = None, - mime_type: str | None = None, - size: int = -1, - storage_key: str | None = None, - dify_model_identity: str | None = FILE_MODEL_IDENTITY, - url: str | None = None, - # Legacy compatibility fields - explicitly accept known extra fields - tool_file_id: str | None = None, - upload_file_id: str | None = None, - datasource_file_id: str | None = None, - ): - legacy_record_id = related_id or tool_file_id or upload_file_id or datasource_file_id - normalized_reference = reference - if normalized_reference is None and legacy_record_id is not None: - normalized_reference = str(legacy_record_id) - _, parsed_storage_key = _parse_reference(normalized_reference) - - super().__init__( - id=id, - type=type, - transfer_method=transfer_method, - remote_url=remote_url, - reference=normalized_reference, - filename=filename, - extension=extension, - mime_type=mime_type, - size=size, - dify_model_identity=dify_model_identity, - url=url, - ) - # Accept legacy constructor fields without promoting them back into the graph model. - _ = tenant_id - self._storage_key = storage_key or parsed_storage_key or "" - - def to_dict(self) -> Mapping[str, str | int | None]: - data = self.model_dump(mode="json") - return { - **data, - "related_id": self.related_id, - "url": self.generate_url(), - } - - @property - def markdown(self) -> str: - url = self.generate_url() - if self.type == FileType.IMAGE: - text = f"![{self.filename or ''}]({url})" - else: - text = f"[{self.filename or url}]({url})" - - return text - - def generate_url(self, for_external: bool = True) -> str | None: - return helpers.resolve_file_url(self, for_external=for_external) - - def to_plugin_parameter(self) -> dict[str, Any]: - return { - "dify_model_identity": FILE_MODEL_IDENTITY, - "mime_type": self.mime_type, - "filename": self.filename, - "extension": self.extension, - "size": self.size, - "type": self.type, - "url": self.generate_url(for_external=False), - } - - @model_validator(mode="after") - def validate_after(self) -> File: - match self.transfer_method: - case FileTransferMethod.REMOTE_URL: - if not self.remote_url: - raise ValueError("Missing file url") - if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): - raise ValueError("Invalid file url") - case FileTransferMethod.LOCAL_FILE: - if not self.reference: - raise ValueError("Missing file reference") - case FileTransferMethod.TOOL_FILE: - if not self.reference: - raise ValueError("Missing file reference") - case FileTransferMethod.DATASOURCE_FILE: - if not self.reference: - raise ValueError("Missing file reference") - return self - - @property - def related_id(self) -> str | None: - record_id, _ = _parse_reference(self.reference) - return record_id - - @related_id.setter - def related_id(self, value: str | None) -> None: - self.reference = value - - @property - def storage_key(self) -> str: - _, storage_key = _parse_reference(self.reference) - return storage_key or self._storage_key - - @storage_key.setter - def storage_key(self, value: str) -> 
None: - self._storage_key = value diff --git a/api/graphon/file/protocols.py b/api/graphon/file/protocols.py deleted file mode 100644 index 0acabe35e5..0000000000 --- a/api/graphon/file/protocols.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import TYPE_CHECKING, Literal, Protocol - -if TYPE_CHECKING: - from .models import File - - -class HttpResponseProtocol(Protocol): - """Subset of response behavior needed by workflow file helpers.""" - - @property - def content(self) -> bytes: ... - - def raise_for_status(self) -> object: ... - - -class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``graphon.file``. - - Implementations are expected to be provided by integration layers (for example, - ``core.app.workflow.file_runtime``) so the workflow package avoids importing - application infrastructure modules directly. - """ - - @property - def multimodal_send_format(self) -> str: ... - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... - - def load_file_bytes(self, *, file: File) -> bytes: ... - - def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ... - - def resolve_upload_file_url( - self, - *, - upload_file_id: str, - as_attachment: bool = False, - for_external: bool = True, - ) -> str: ... - - def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... - - def verify_preview_signature( - self, - *, - preview_kind: Literal["image", "file"], - file_id: str, - timestamp: str, - nonce: str, - sign: str, - ) -> bool: ... diff --git a/api/graphon/file/runtime.py b/api/graphon/file/runtime.py deleted file mode 100644 index 1c5d1c3ca4..0000000000 --- a/api/graphon/file/runtime.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import TYPE_CHECKING, Literal, NoReturn - -from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol - -if TYPE_CHECKING: - from .models import File - - -class WorkflowFileRuntimeNotConfiguredError(RuntimeError): - """Raised when workflow file runtime dependencies were not configured.""" - - -class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - def _raise(self) -> NoReturn: - raise WorkflowFileRuntimeNotConfiguredError( - "workflow file runtime is not configured, call set_workflow_file_runtime(...) 
first" - ) - - @property - def multimodal_send_format(self) -> str: - self._raise() - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - self._raise() - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: - self._raise() - - def load_file_bytes(self, *, file: File) -> bytes: - self._raise() - - def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: - self._raise() - - def resolve_upload_file_url( - self, - *, - upload_file_id: str, - as_attachment: bool = False, - for_external: bool = True, - ) -> str: - self._raise() - - def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: - self._raise() - - def verify_preview_signature( - self, - *, - preview_kind: Literal["image", "file"], - file_id: str, - timestamp: str, - nonce: str, - sign: str, - ) -> bool: - self._raise() - - -_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime() - - -def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None: - global _runtime - _runtime = runtime - - -def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol: - return _runtime diff --git a/api/graphon/file/tool_file_parser.py b/api/graphon/file/tool_file_parser.py deleted file mode 100644 index 2d7a3d43df..0000000000 --- a/api/graphon/file/tool_file_parser.py +++ /dev/null @@ -1,9 +0,0 @@ -from collections.abc import Callable -from typing import Any - -_tool_file_manager_factory: Callable[[], Any] | None = None - - -def set_tool_file_manager_factory(factory: Callable[[], Any]): - global _tool_file_manager_factory - _tool_file_manager_factory = factory diff --git a/api/graphon/graph/__init__.py b/api/graphon/graph/__init__.py deleted file mode 100644 index 4830ea83d3..0000000000 --- a/api/graphon/graph/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .edge import Edge -from .graph import Graph, GraphBuilder, NodeFactory -from .graph_template import GraphTemplate - -__all__ = [ - "Edge", - "Graph", - "GraphBuilder", - "GraphTemplate", - "NodeFactory", -] diff --git a/api/graphon/graph/edge.py b/api/graphon/graph/edge.py deleted file mode 100644 index 1f8a2884e3..0000000000 --- a/api/graphon/graph/edge.py +++ /dev/null @@ -1,15 +0,0 @@ -import uuid -from dataclasses import dataclass, field - -from graphon.enums import NodeState - - -@dataclass -class Edge: - """Edge connecting two nodes in a workflow graph.""" - - id: str = field(default_factory=lambda: str(uuid.uuid4())) - tail: str = "" # tail node id (source) - head: str = "" # head node id (target) - source_handle: str = "source" # source handle for conditional branching - state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state diff --git a/api/graphon/graph/graph.py b/api/graphon/graph/graph.py deleted file mode 100644 index 0f4cd8925f..0000000000 --- a/api/graphon/graph/graph.py +++ /dev/null @@ -1,438 +0,0 @@ -from __future__ import annotations - -import logging -from collections import defaultdict -from collections.abc import Mapping, Sequence -from typing import Protocol, cast, final - -from pydantic import TypeAdapter - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState -from graphon.nodes.base.node import Node - -from .edge import Edge -from .validation import get_graph_validator - -logger = logging.getLogger(__name__) - -_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict]) - - -class NodeFactory(Protocol): - """ 
- Protocol for creating Node instances from node data dictionaries. - - This protocol decouples the Graph class from specific node mapping implementations, - allowing for different node creation strategies while maintaining type safety. - """ - - def create_node(self, node_config: NodeConfigDict) -> Node: - """ - Create a Node instance from node configuration data. - - :param node_config: node configuration dictionary containing type and other data - :return: initialized Node instance - :raises ValueError: if node type is unknown or no implementation exists for the resolved version - :raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation - """ - ... - - -@final -class Graph: - """Graph representation with nodes and edges for workflow execution.""" - - def __init__( - self, - *, - nodes: dict[str, Node] | None = None, - edges: dict[str, Edge] | None = None, - in_edges: dict[str, list[str]] | None = None, - out_edges: dict[str, list[str]] | None = None, - root_node: Node, - ): - """ - Initialize Graph instance. - - :param nodes: graph nodes mapping (node id: node object) - :param edges: graph edges mapping (edge id: edge object) - :param in_edges: incoming edges mapping (node id: list of edge ids) - :param out_edges: outgoing edges mapping (node id: list of edge ids) - :param root_node: root node object - """ - self.nodes = nodes or {} - self.edges = edges or {} - self.in_edges = in_edges or {} - self.out_edges = out_edges or {} - self.root_node = root_node - - @classmethod - def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]: - """ - Parse node configurations and build a mapping of node IDs to configs. - - :param node_configs: list of node configuration dictionaries - :return: mapping of node ID to node config - """ - node_configs_map: dict[str, NodeConfigDict] = {} - - for node_config in node_configs: - node_configs_map[node_config["id"]] = node_config - - return node_configs_map - - @classmethod - def _build_edges( - cls, edge_configs: list[dict[str, object]] - ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]: - """ - Build edge objects and mappings from edge configurations. - - :param edge_configs: list of edge configurations - :return: tuple of (edges dict, in_edges dict, out_edges dict) - """ - edges: dict[str, Edge] = {} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - edge_counter = 0 - for edge_config in edge_configs: - source = edge_config.get("source") - target = edge_config.get("target") - - if not isinstance(source, str) or not isinstance(target, str): - continue - - # Create edge - edge_id = f"edge_{edge_counter}" - edge_counter += 1 - - source_handle = edge_config.get("sourceHandle", "source") - if not isinstance(source_handle, str): - continue - - edge = Edge( - id=edge_id, - tail=source, - head=target, - source_handle=source_handle, - ) - - edges[edge_id] = edge - out_edges[source].append(edge_id) - in_edges[target].append(edge_id) - - return edges, dict(in_edges), dict(out_edges) - - @classmethod - def _create_node_instances( - cls, - node_configs_map: dict[str, NodeConfigDict], - node_factory: NodeFactory, - ) -> dict[str, Node]: - """ - Create node instances from configurations using the node factory. 
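To make the shapes concrete, a sketch of what `_build_edges` above produces for a two-edge config (node ids are illustrative):

```python
edge_configs = [
    {"source": "start", "target": "llm", "sourceHandle": "source"},
    {"source": "llm", "target": "end"},  # sourceHandle defaults to "source"
]
edges, in_edges, out_edges = Graph._build_edges(edge_configs)
# edges     -> {"edge_0": Edge(tail="start", head="llm", ...), "edge_1": Edge(tail="llm", head="end", ...)}
# out_edges -> {"start": ["edge_0"], "llm": ["edge_1"]}
# in_edges  -> {"llm": ["edge_0"], "end": ["edge_1"]}
```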
- - :param node_configs_map: mapping of node ID to node config - :param node_factory: factory for creating node instances - :return: mapping of node ID to node instance - """ - nodes: dict[str, Node] = {} - - for node_id, node_config in node_configs_map.items(): - try: - node_instance = node_factory.create_node(node_config) - except Exception: - logger.exception("Failed to create node instance for node_id %s", node_id) - raise - nodes[node_id] = node_instance - - return nodes - - @classmethod - def new(cls) -> GraphBuilder: - """Create a fluent builder for assembling a graph programmatically.""" - - return GraphBuilder(graph_cls=cls) - - @staticmethod - def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]: - """ - Remove editor-only nodes before `NodeConfigDict` validation. - - Persisted note widgets use a top-level `type == "custom-note"` but leave - `data.type` empty because they are never executable graph nodes. Filter - them while configs are still raw dicts so Pydantic does not validate - their placeholder payloads against `BaseNodeData.type: NodeType`. - """ - filtered_node_configs: list[dict[str, object]] = [] - for node_config in node_configs: - if node_config.get("type", "") == "custom-note": - continue - filtered_node_configs.append(dict(node_config)) - return filtered_node_configs - - @classmethod - def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: - """ - Promote nodes configured with FAIL_BRANCH error strategy to branch execution type. - - :param nodes: mapping of node ID to node instance - """ - for node in nodes.values(): - if node.error_strategy == ErrorStrategy.FAIL_BRANCH: - node.execution_type = NodeExecutionType.BRANCH - - @classmethod - def _mark_inactive_root_branches( - cls, - nodes: dict[str, Node], - edges: dict[str, Edge], - in_edges: dict[str, list[str]], - out_edges: dict[str, list[str]], - active_root_id: str, - ) -> None: - """ - Mark nodes and edges from inactive root branches as skipped. - - Algorithm: - 1. Mark inactive root nodes as skipped - 2. For skipped nodes, mark all their outgoing edges as skipped - 3. 
For each edge marked as skipped, check its target node: - - If ALL incoming edges are skipped, mark the node as skipped - - Otherwise, leave the node state unchanged - - :param nodes: mapping of node ID to node instance - :param edges: mapping of edge ID to edge instance - :param in_edges: mapping of node ID to incoming edge IDs - :param out_edges: mapping of node ID to outgoing edge IDs - :param active_root_id: ID of the active root node - """ - # Find all top-level root nodes (nodes with ROOT execution type and no incoming edges) - top_level_roots: list[str] = [ - node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT - ] - - # If there's only one root or the active root is not a top-level root, no marking needed - if len(top_level_roots) <= 1 or active_root_id not in top_level_roots: - return - - # Mark inactive root nodes as skipped - inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id] - for root_id in inactive_roots: - if root_id in nodes: - nodes[root_id].state = NodeState.SKIPPED - - # Recursively mark downstream nodes and edges - def mark_downstream(node_id: str) -> None: - """Recursively mark downstream nodes and edges as skipped.""" - if nodes[node_id].state != NodeState.SKIPPED: - return - # If this node is skipped, mark all its outgoing edges as skipped - out_edge_ids = out_edges.get(node_id, []) - for edge_id in out_edge_ids: - edge = edges[edge_id] - edge.state = NodeState.SKIPPED - - # Check the target node of this edge - target_node = nodes[edge.head] - in_edge_ids = in_edges.get(target_node.id, []) - in_edge_states = [edges[eid].state for eid in in_edge_ids] - - # If all incoming edges are skipped, mark the node as skipped - if all(state == NodeState.SKIPPED for state in in_edge_states): - target_node.state = NodeState.SKIPPED - # Recursively process downstream nodes - mark_downstream(target_node.id) - - # Process each inactive root and its downstream nodes - for root_id in inactive_roots: - mark_downstream(root_id) - - @classmethod - def init( - cls, - *, - graph_config: Mapping[str, object], - node_factory: NodeFactory, - root_node_id: str, - skip_validation: bool = False, - ) -> Graph: - """ - Initialize a graph with an explicit execution entry point. 
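An illustrative call of this entry point (the factory is hypothetical and the node payloads are schematic; real configs must pass `NodeConfigDict` validation):

```python
graph = Graph.init(
    graph_config={
        "nodes": [
            {"id": "start", "data": {"type": "start", "title": "Start"}},
            {"id": "end", "data": {"type": "end", "title": "End"}},
        ],
        "edges": [{"source": "start", "target": "end"}],
    },
    node_factory=my_node_factory,  # hypothetical; must satisfy the NodeFactory protocol
    root_node_id="start",
)
```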
- - :param graph_config: graph config containing nodes and edges - :param node_factory: factory for creating node instances from config data - :param root_node_id: active root node id - :return: graph instance - """ - # Parse configs - edge_configs = graph_config.get("edges", []) - node_configs = graph_config.get("nodes", []) - - edge_configs = cast(list[dict[str, object]], edge_configs) - node_configs = cast(list[dict[str, object]], node_configs) - node_configs = cls._filter_canvas_only_nodes(node_configs) - node_configs = _ListNodeConfigDict.validate_python(node_configs) - - if not node_configs: - raise ValueError("Graph must have at least one node") - - # Parse node configurations - node_configs_map = cls._parse_node_configs(node_configs) - - if root_node_id not in node_configs_map: - raise ValueError(f"Root node id {root_node_id} not found in the graph") - - # Build edges - edges, in_edges, out_edges = cls._build_edges(edge_configs) - - # Create node instances - nodes = cls._create_node_instances(node_configs_map, node_factory) - - # Promote fail-branch nodes to branch execution type at graph level - cls._promote_fail_branch_nodes(nodes) - - # Get root node instance - root_node = nodes[root_node_id] - - # Mark inactive root branches as skipped - cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) - - # Create and return the graph - graph = cls( - nodes=nodes, - edges=edges, - in_edges=in_edges, - out_edges=out_edges, - root_node=root_node, - ) - - if not skip_validation: - # Validate the graph structure using built-in validators - get_graph_validator().validate(graph) - - return graph - - @property - def node_ids(self) -> list[str]: - """ - Get list of node IDs (compatibility property for existing code) - - :return: list of node IDs - """ - return list(self.nodes.keys()) - - def get_outgoing_edges(self, node_id: str) -> list[Edge]: - """ - Get all outgoing edges from a node (V2 method) - - :param node_id: node id - :return: list of outgoing edges - """ - edge_ids = self.out_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - def get_incoming_edges(self, node_id: str) -> list[Edge]: - """ - Get all incoming edges to a node (V2 method) - - :param node_id: node id - :return: list of incoming edges - """ - edge_ids = self.in_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - -@final -class GraphBuilder: - """Fluent helper for constructing simple graphs, primarily for tests.""" - - def __init__(self, *, graph_cls: type[Graph]): - self._graph_cls = graph_cls - self._nodes: list[Node] = [] - self._nodes_by_id: dict[str, Node] = {} - self._edges: list[Edge] = [] - self._edge_counter = 0 - - def add_root(self, node: Node) -> GraphBuilder: - """Register the root node. 
Must be called exactly once.""" - - if self._nodes: - raise ValueError("Root node has already been added") - self._register_node(node) - self._nodes.append(node) - return self - - def add_node( - self, - node: Node, - *, - from_node_id: str | None = None, - source_handle: str = "source", - ) -> GraphBuilder: - """Append a node and connect it from the specified predecessor.""" - - if not self._nodes: - raise ValueError("Root node must be added before adding other nodes") - - predecessor_id = from_node_id or self._nodes[-1].id - if predecessor_id not in self._nodes_by_id: - raise ValueError(f"Predecessor node '{predecessor_id}' not found") - - predecessor = self._nodes_by_id[predecessor_id] - self._register_node(node) - self._nodes.append(node) - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle) - self._edges.append(edge) - - return self - - def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder: - """Connect two existing nodes without adding a new node.""" - - if tail not in self._nodes_by_id: - raise ValueError(f"Tail node '{tail}' not found") - if head not in self._nodes_by_id: - raise ValueError(f"Head node '{head}' not found") - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle) - self._edges.append(edge) - - return self - - def build(self) -> Graph: - """Materialize the graph instance from the accumulated nodes and edges.""" - - if not self._nodes: - raise ValueError("Cannot build an empty graph") - - nodes = {node.id: node for node in self._nodes} - edges = {edge.id: edge for edge in self._edges} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - for edge in self._edges: - out_edges[edge.tail].append(edge.id) - in_edges[edge.head].append(edge.id) - - return self._graph_cls( - nodes=nodes, - edges=edges, - in_edges=dict(in_edges), - out_edges=dict(out_edges), - root_node=self._nodes[0], - ) - - def _register_node(self, node: Node) -> None: - if not node.id: - raise ValueError("Node must have a non-empty id") - if node.id in self._nodes_by_id: - raise ValueError(f"Duplicate node id detected: {node.id}") - self._nodes_by_id[node.id] = node diff --git a/api/graphon/graph/graph_template.py b/api/graphon/graph/graph_template.py deleted file mode 100644 index 34e2dc19e6..0000000000 --- a/api/graphon/graph/graph_template.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - - -class GraphTemplate(BaseModel): - """ - Graph Template for container nodes and subgraph expansion - - According to GraphEngine V2 spec, GraphTemplate contains: - - nodes: mapping of node definitions - - edges: mapping of edge definitions - - root_ids: list of root node IDs - - output_selectors: list of output selectors for the template - """ - - nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping") - edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping") - root_ids: list[str] = Field(default_factory=list, description="root node IDs") - output_selectors: list[str] = Field(default_factory=list, description="output selectors") diff --git a/api/graphon/graph/validation.py b/api/graphon/graph/validation.py deleted file mode 100644 index 04b501fd33..0000000000 --- 
a/api/graphon/graph/validation.py +++ /dev/null @@ -1,125 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeType - -if TYPE_CHECKING: - from .graph import Graph - - -@dataclass(frozen=True, slots=True) -class GraphValidationIssue: - """Immutable value object describing a single validation issue.""" - - code: str - message: str - node_id: str | None = None - - -class GraphValidationError(ValueError): - """Raised when graph validation fails.""" - - def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: - if not issues: - raise ValueError("GraphValidationError requires at least one issue.") - self.issues: tuple[GraphValidationIssue, ...] = tuple(issues) - message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues) - super().__init__(message) - - -class GraphValidationRule(Protocol): - """Protocol that individual validation rules must satisfy.""" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - """Validate the provided graph and return any discovered issues.""" - ... - - -@dataclass(frozen=True, slots=True) -class _EdgeEndpointValidator: - """Ensures all edges reference existing nodes.""" - - missing_node_code: str = "MISSING_NODE" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - issues: list[GraphValidationIssue] = [] - for edge in graph.edges.values(): - if edge.tail not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown source node '{edge.tail}'.", - node_id=edge.tail, - ) - ) - if edge.head not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown target node '{edge.head}'.", - node_id=edge.head, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class _RootNodeValidator: - """Validates root node invariants.""" - - invalid_root_code: str = "INVALID_ROOT" - container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START) - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - root_node = graph.root_node - issues: list[GraphValidationIssue] = [] - if root_node.id not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' is missing from the node registry.", - node_id=root_node.id, - ) - ) - return issues - - node_type = root_node.node_type - if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' must declare execution type 'root'.", - node_id=root_node.id, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class GraphValidator: - """Coordinates execution of graph validation rules.""" - - rules: tuple[GraphValidationRule, ...] - - def validate(self, graph: Graph) -> None: - """Validate the graph against all configured rules.""" - issues: list[GraphValidationIssue] = [] - for rule in self.rules: - issues.extend(rule.validate(graph)) - - if issues: - raise GraphValidationError(issues) - - -_DEFAULT_RULES: tuple[GraphValidationRule, ...] 
= ( - _EdgeEndpointValidator(), - _RootNodeValidator(), -) - - -def get_graph_validator() -> GraphValidator: - """Construct the validator composed of default rules.""" - return GraphValidator(_DEFAULT_RULES) diff --git a/api/graphon/graph_engine/__init__.py b/api/graphon/graph_engine/__init__.py deleted file mode 100644 index 0e1c7dd60a..0000000000 --- a/api/graphon/graph_engine/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .config import GraphEngineConfig -from .graph_engine import GraphEngine - -__all__ = ["GraphEngine", "GraphEngineConfig"] diff --git a/api/graphon/graph_engine/_engine_utils.py b/api/graphon/graph_engine/_engine_utils.py deleted file mode 100644 index 28898268fe..0000000000 --- a/api/graphon/graph_engine/_engine_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import time - - -def get_timestamp() -> float: - """Retrieve a timestamp as a floating-point number representing the number of seconds - since the Unix epoch. - - This function is primarily used to measure the execution time of the workflow engine. - Since workflow execution may be paused and resumed on a different machine, - `time.perf_counter` cannot be used as it is inconsistent across machines. - - To address this, the function uses the wall clock as the time source. - However, it assumes that the clocks of all servers are properly synchronized. - """ - return round(time.time()) diff --git a/api/graphon/graph_engine/command_channels/README.md b/api/graphon/graph_engine/command_channels/README.md deleted file mode 100644 index e35e12054a..0000000000 --- a/api/graphon/graph_engine/command_channels/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# Command Channels - -Channel implementations for external workflow control. - -## Components - -### InMemoryChannel - -Thread-safe in-memory queue for single-process deployments. - -- `fetch_commands()` - Get pending commands -- `send_command()` - Add command to queue - -### RedisChannel - -Redis-based queue for distributed deployments. - -- `fetch_commands()` - Get commands with JSON deserialization -- `send_command()` - Store commands with TTL - -## Usage - -```python -# Local execution -channel = InMemoryChannel() -channel.send_command(AbortCommand(reason="user requested abort")) - -# Distributed execution -redis_channel = RedisChannel( - redis_client=redis_client, - channel_key="workflow:123:commands" -) -``` diff --git a/api/graphon/graph_engine/command_channels/__init__.py b/api/graphon/graph_engine/command_channels/__init__.py deleted file mode 100644 index 863e6032d6..0000000000 --- a/api/graphon/graph_engine/command_channels/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Command channel implementations for GraphEngine.""" - -from .in_memory_channel import InMemoryChannel -from .redis_channel import RedisChannel - -__all__ = ["InMemoryChannel", "RedisChannel"] diff --git a/api/graphon/graph_engine/command_channels/in_memory_channel.py b/api/graphon/graph_engine/command_channels/in_memory_channel.py deleted file mode 100644 index bdaf236796..0000000000 --- a/api/graphon/graph_engine/command_channels/in_memory_channel.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -In-memory implementation of CommandChannel for local/testing scenarios. - -This implementation uses a thread-safe queue for command communication -within a single process. Each instance handles commands for one workflow execution.
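Round-trip behavior of the channel defined below (a sketch; `AbortCommand` comes from `entities/commands.py`, deleted later in this diff):

```python
channel = InMemoryChannel()
channel.send_command(AbortCommand(reason="user clicked stop"))

commands = channel.fetch_commands()    # -> [AbortCommand(...)]
assert channel.fetch_commands() == []  # each fetch drains the queue
```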
-""" - -from queue import Queue -from typing import final - -from ..entities.commands import GraphEngineCommand - - -@final -class InMemoryChannel: - """ - In-memory command channel implementation using a thread-safe queue. - - Each instance is dedicated to a single GraphEngine/workflow execution. - Suitable for local development, testing, and single-instance deployments. - """ - - def __init__(self) -> None: - """Initialize the in-memory channel with a single queue.""" - self._queue: Queue[GraphEngineCommand] = Queue() - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from the queue. - - Returns: - List of pending commands (drains the queue) - """ - commands: list[GraphEngineCommand] = [] - - # Drain all available commands from the queue - while not self._queue.empty(): - try: - command = self._queue.get_nowait() - commands.append(command) - except Exception: - break - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to this channel's queue. - - Args: - command: The command to send - """ - self._queue.put(command) diff --git a/api/graphon/graph_engine/command_channels/redis_channel.py b/api/graphon/graph_engine/command_channels/redis_channel.py deleted file mode 100644 index 77cf884c67..0000000000 --- a/api/graphon/graph_engine/command_channels/redis_channel.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Redis-based implementation of CommandChannel for distributed scenarios. - -This implementation uses Redis lists for command queuing, supporting -multi-instance deployments and cross-server communication. -Each instance uses a unique key for its command queue. -""" - -import json -from contextlib import AbstractContextManager -from typing import Any, Protocol, final - -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand - - -class RedisPipelineProtocol(Protocol): - """Minimal Redis pipeline contract used by the command channel.""" - - def lrange(self, name: str, start: int, end: int) -> Any: ... - def delete(self, *names: str) -> Any: ... - def execute(self) -> list[Any]: ... - def rpush(self, name: str, *values: str) -> Any: ... - def expire(self, name: str, time: int) -> Any: ... - def set(self, name: str, value: str, ex: int | None = None) -> Any: ... - def get(self, name: str) -> Any: ... - - -class RedisClientProtocol(Protocol): - """Redis client contract required by the command channel.""" - - def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ... - - -@final -class RedisChannel: - """ - Redis-based command channel implementation for distributed systems. - - Each instance uses a unique Redis key for its command queue. - Commands are JSON-serialized for transport. - """ - - def __init__( - self, - redis_client: RedisClientProtocol, - channel_key: str, - command_ttl: int = 3600, - ) -> None: - """ - Initialize the Redis channel. - - Args: - redis_client: Redis client instance - channel_key: Unique key for this channel's command queue - command_ttl: TTL for command keys in seconds (default: 3600) - """ - self._redis = redis_client - self._key = channel_key - self._command_ttl = command_ttl - self._pending_key = f"{channel_key}:pending" - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from Redis. 
- - Returns: - List of pending commands (drains the Redis list) - """ - if not self._has_pending_commands(): - return [] - - commands: list[GraphEngineCommand] = [] - - # Use pipeline for atomic operations - with self._redis.pipeline() as pipe: - # Get all commands and clear the list atomically - pipe.lrange(self._key, 0, -1) - pipe.delete(self._key) - results = pipe.execute() - - # Parse commands from JSON - if results[0]: - for command_json in results[0]: - try: - command_data = json.loads(command_json) - command = self._deserialize_command(command_data) - if command: - commands.append(command) - except (json.JSONDecodeError, ValueError): - # Skip invalid commands - continue - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to Redis. - - Args: - command: The command to send - """ - command_json = json.dumps(command.model_dump()) - - # Push to list and set expiry - with self._redis.pipeline() as pipe: - pipe.rpush(self._key, command_json) - pipe.expire(self._key, self._command_ttl) - pipe.set(self._pending_key, "1", ex=self._command_ttl) - pipe.execute() - - def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None: - """ - Deserialize a command from dictionary data. - - Args: - data: Command data dictionary - - Returns: - Deserialized command or None if invalid - """ - command_type_value = data.get("command_type") - if not isinstance(command_type_value, str): - return None - - try: - command_type = CommandType(command_type_value) - - if command_type == CommandType.ABORT: - return AbortCommand.model_validate(data) - if command_type == CommandType.PAUSE: - return PauseCommand.model_validate(data) - if command_type == CommandType.UPDATE_VARIABLES: - return UpdateVariablesCommand.model_validate(data) - - # For other command types, use base class - return GraphEngineCommand.model_validate(data) - - except (ValueError, TypeError): - return None - - def _has_pending_commands(self) -> bool: - """ - Check and consume the pending marker to avoid unnecessary list reads. - - Returns: - True if commands should be fetched from Redis. - """ - with self._redis.pipeline() as pipe: - pipe.get(self._pending_key) - pipe.delete(self._pending_key) - pending_value, _ = pipe.execute() - - return pending_value is not None diff --git a/api/graphon/graph_engine/command_processing/__init__.py b/api/graphon/graph_engine/command_processing/__init__.py deleted file mode 100644 index 7b4f0dfff7..0000000000 --- a/api/graphon/graph_engine/command_processing/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Command processing subsystem for graph engine. - -This package handles external commands sent to the engine -during execution. 
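Looking back at `RedisChannel` above, a distributed usage sketch (assumes the standard `redis-py` client, whose pipeline supports the context-manager protocol required here; host and key names are illustrative):

```python
import redis  # assumption: the standard redis-py client

r = redis.Redis(host="localhost", port=6379)
channel = RedisChannel(redis_client=r, channel_key="workflow:abc123:commands")

# An API process enqueues; a worker on another host drains the same key.
channel.send_command(PauseCommand(reason="maintenance window"))
commands = channel.fetch_commands()
```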
-""" - -from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler -from .command_processor import CommandProcessor - -__all__ = [ - "AbortCommandHandler", - "CommandProcessor", - "PauseCommandHandler", - "UpdateVariablesCommandHandler", -] diff --git a/api/graphon/graph_engine/command_processing/command_handlers.py b/api/graphon/graph_engine/command_processing/command_handlers.py deleted file mode 100644 index ad92fd1abb..0000000000 --- a/api/graphon/graph_engine/command_processing/command_handlers.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -from typing import final - -from typing_extensions import override - -from graphon.entities.pause_reason import SchedulingPause -from graphon.runtime import VariablePool - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand -from .command_processor import CommandHandler - -logger = logging.getLogger(__name__) - - -@final -class AbortCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, AbortCommand) - logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason) - execution.abort(command.reason or "User requested abort") - - -@final -class PauseCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, PauseCommand) - logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason) - # Convert string reason to PauseReason if needed - reason = command.reason - pause_reason = SchedulingPause(message=reason) - execution.pause(pause_reason) - - -@final -class UpdateVariablesCommandHandler(CommandHandler): - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, UpdateVariablesCommand) - for update in command.updates: - try: - variable = update.value - self._variable_pool.add(variable.selector, variable) - logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) - except ValueError as exc: - logger.warning( - "Skipping invalid variable selector %s for workflow %s: %s", - getattr(update.value, "selector", None), - execution.workflow_id, - exc, - ) diff --git a/api/graphon/graph_engine/command_processing/command_processor.py b/api/graphon/graph_engine/command_processing/command_processor.py deleted file mode 100644 index 942c2d77a5..0000000000 --- a/api/graphon/graph_engine/command_processing/command_processor.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Main command processor for handling external commands. -""" - -import logging -from typing import Protocol, final - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import GraphEngineCommand -from ..protocols.command_channel import CommandChannel - -logger = logging.getLogger(__name__) - - -class CommandHandler(Protocol): - """Protocol for command handlers.""" - - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ... - - -@final -class CommandProcessor: - """ - Processes external commands sent to the engine. - - This polls the command channel and dispatches commands to - appropriate handlers. 
- """ - - def __init__( - self, - command_channel: CommandChannel, - graph_execution: GraphExecution, - ) -> None: - """ - Initialize the command processor. - - Args: - command_channel: Channel for receiving commands - graph_execution: Graph execution aggregate - """ - self._command_channel = command_channel - self._graph_execution = graph_execution - self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {} - - def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None: - """ - Register a handler for a command type. - - Args: - command_type: Type of command to handle - handler: Handler for the command - """ - self._handlers[command_type] = handler - - def process_commands(self) -> None: - """Check for and process any pending commands.""" - try: - commands = self._command_channel.fetch_commands() - for command in commands: - self._handle_command(command) - except Exception as e: - logger.warning("Error processing commands: %s", e) - - def _handle_command(self, command: GraphEngineCommand) -> None: - """ - Handle a single command. - - Args: - command: The command to handle - """ - handler = self._handlers.get(type(command)) - if handler: - try: - handler.handle(command, self._graph_execution) - except Exception: - logger.exception("Error handling command %s", command.__class__.__name__) - else: - logger.warning("No handler registered for command: %s", command.__class__.__name__) diff --git a/api/graphon/graph_engine/config.py b/api/graphon/graph_engine/config.py deleted file mode 100644 index d56a69cee0..0000000000 --- a/api/graphon/graph_engine/config.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -GraphEngine configuration models. -""" - -from pydantic import BaseModel, ConfigDict - - -class GraphEngineConfig(BaseModel): - """Configuration for GraphEngine worker pool scaling.""" - - model_config = ConfigDict(frozen=True) - - min_workers: int = 1 - max_workers: int = 5 - scale_up_threshold: int = 3 - scale_down_idle_time: float = 5.0 diff --git a/api/graphon/graph_engine/domain/__init__.py b/api/graphon/graph_engine/domain/__init__.py deleted file mode 100644 index 9e9afe4c21..0000000000 --- a/api/graphon/graph_engine/domain/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Domain models for graph engine. - -This package contains the core domain entities, value objects, and aggregates -that represent the business concepts of workflow graph execution. 
-""" - -from .graph_execution import GraphExecution -from .node_execution import NodeExecution - -__all__ = [ - "GraphExecution", - "NodeExecution", -] diff --git a/api/graphon/graph_engine/domain/graph_execution.py b/api/graphon/graph_engine/domain/graph_execution.py deleted file mode 100644 index 9c0c7d1624..0000000000 --- a/api/graphon/graph_engine/domain/graph_execution.py +++ /dev/null @@ -1,242 +0,0 @@ -"""GraphExecution aggregate root managing the overall graph execution state.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from importlib import import_module -from typing import Literal - -from pydantic import BaseModel, Field - -from graphon.entities.pause_reason import PauseReason -from graphon.enums import NodeState -from graphon.runtime.graph_runtime_state import GraphExecutionProtocol - -from .node_execution import NodeExecution - - -class GraphExecutionErrorState(BaseModel): - """Serializable representation of an execution error.""" - - module: str = Field(description="Module containing the exception class") - qualname: str = Field(description="Qualified name of the exception class") - message: str | None = Field(default=None, description="Exception message string") - - -class NodeExecutionState(BaseModel): - """Serializable representation of a node execution entity.""" - - node_id: str - state: NodeState = Field(default=NodeState.UNKNOWN) - retry_count: int = Field(default=0) - execution_id: str | None = Field(default=None) - error: str | None = Field(default=None) - - -class GraphExecutionState(BaseModel): - """Pydantic model describing serialized GraphExecution state.""" - - type: Literal["GraphExecution"] = Field(default="GraphExecution") - version: str = Field(default="1.0") - workflow_id: str - started: bool = Field(default=False) - completed: bool = Field(default=False) - aborted: bool = Field(default=False) - paused: bool = Field(default=False) - pause_reasons: list[PauseReason] = Field(default_factory=list) - error: GraphExecutionErrorState | None = Field(default=None) - exceptions_count: int = Field(default=0) - node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) - - -def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: - """Convert an exception into its serializable representation.""" - - if error is None: - return None - - return GraphExecutionErrorState( - module=error.__class__.__module__, - qualname=error.__class__.__qualname__, - message=str(error), - ) - - -def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]: - """Locate an exception class from its module and qualified name.""" - - module = import_module(module_name) - attr: object = module - for part in qualname.split("."): - attr = getattr(attr, part) - - if isinstance(attr, type) and issubclass(attr, Exception): - return attr - - raise TypeError(f"{qualname} in {module_name} is not an Exception subclass") - - -def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None: - """Reconstruct an exception instance from serialized data.""" - - if state is None: - return None - - try: - exception_class = _resolve_exception_class(state.module, state.qualname) - if state.message is None: - return exception_class() - return exception_class(state.message) - except Exception: - # Fallback to RuntimeError when reconstruction fails - if state.message is None: - return RuntimeError(state.qualname) - return RuntimeError(state.message) - - -@dataclass -class 
GraphExecution: - """ - Aggregate root for graph execution. - - This manages the overall execution state of a workflow graph, - coordinating between multiple node executions. - """ - - workflow_id: str - started: bool = False - completed: bool = False - aborted: bool = False - paused: bool = False - pause_reasons: list[PauseReason] = field(default_factory=list) - error: Exception | None = None - node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) - exceptions_count: int = 0 - - def start(self) -> None: - """Mark the graph execution as started.""" - if self.started: - raise RuntimeError("Graph execution already started") - self.started = True - - def complete(self) -> None: - """Mark the graph execution as completed.""" - if not self.started: - raise RuntimeError("Cannot complete execution that hasn't started") - if self.completed: - raise RuntimeError("Graph execution already completed") - self.completed = True - - def abort(self, reason: str) -> None: - """Abort the graph execution.""" - self.aborted = True - self.error = RuntimeError(f"Aborted: {reason}") - - def pause(self, reason: PauseReason) -> None: - """Pause the graph execution without marking it complete.""" - if self.completed: - raise RuntimeError("Cannot pause execution that has completed") - if self.aborted: - raise RuntimeError("Cannot pause execution that has been aborted") - self.paused = True - self.pause_reasons.append(reason) - - def fail(self, error: Exception) -> None: - """Mark the graph execution as failed.""" - self.error = error - self.completed = True - - def get_or_create_node_execution(self, node_id: str) -> NodeExecution: - """Get or create a node execution entity.""" - if node_id not in self.node_executions: - self.node_executions[node_id] = NodeExecution(node_id=node_id) - return self.node_executions[node_id] - - @property - def is_running(self) -> bool: - """Check if the execution is currently running.""" - return self.started and not self.completed and not self.aborted and not self.paused - - @property - def is_paused(self) -> bool: - """Check if the execution is currently paused.""" - return self.paused - - @property - def has_error(self) -> bool: - """Check if the execution has encountered an error.""" - return self.error is not None - - @property - def error_message(self) -> str | None: - """Get the error message if an error exists.""" - if not self.error: - return None - return str(self.error) - - def dumps(self) -> str: - """Serialize the aggregate state into a JSON string.""" - - node_states = [ - NodeExecutionState( - node_id=node_id, - state=node_execution.state, - retry_count=node_execution.retry_count, - execution_id=node_execution.execution_id, - error=node_execution.error, - ) - for node_id, node_execution in sorted(self.node_executions.items()) - ] - - state = GraphExecutionState( - workflow_id=self.workflow_id, - started=self.started, - completed=self.completed, - aborted=self.aborted, - paused=self.paused, - pause_reasons=self.pause_reasons, - error=_serialize_error(self.error), - exceptions_count=self.exceptions_count, - node_executions=node_states, - ) - - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore aggregate state from a serialized JSON string.""" - - state = GraphExecutionState.model_validate_json(data) - - if state.type != "GraphExecution": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - 
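Taken together, `dumps()` and `loads()` are what let a paused run resume in another process; a round-trip sketch:

```python
execution = GraphExecution(workflow_id="wf-1")
execution.start()
snapshot = execution.dumps()  # JSON string, safe to persist

resumed = GraphExecution(workflow_id="wf-1")  # identity must match the snapshot
resumed.loads(snapshot)
assert resumed.started and not resumed.completed
```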
- if self.workflow_id != state.workflow_id: - raise ValueError("Serialized workflow_id does not match aggregate identity") - - self.started = state.started - self.completed = state.completed - self.aborted = state.aborted - self.paused = state.paused - self.pause_reasons = state.pause_reasons - self.error = _deserialize_error(state.error) - self.exceptions_count = state.exceptions_count - self.node_executions = { - item.node_id: NodeExecution( - node_id=item.node_id, - state=item.state, - retry_count=item.retry_count, - execution_id=item.execution_id, - error=item.error, - ) - for item in state.node_executions - } - - def record_node_failure(self) -> None: - """Increment the count of node failures encountered during execution.""" - self.exceptions_count += 1 - - -_: GraphExecutionProtocol = GraphExecution(workflow_id="") diff --git a/api/graphon/graph_engine/domain/node_execution.py b/api/graphon/graph_engine/domain/node_execution.py deleted file mode 100644 index dafd6ccd8a..0000000000 --- a/api/graphon/graph_engine/domain/node_execution.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -NodeExecution entity representing a node's execution state. -""" - -from dataclasses import dataclass - -from graphon.enums import NodeState - - -@dataclass -class NodeExecution: - """ - Entity representing the execution state of a single node. - - This is a mutable entity that tracks the runtime state of a node - during graph execution. - """ - - node_id: str - state: NodeState = NodeState.UNKNOWN - retry_count: int = 0 - execution_id: str | None = None - error: str | None = None - - def mark_started(self, execution_id: str) -> None: - """Mark the node as started with an execution ID.""" - self.state = NodeState.TAKEN - self.execution_id = execution_id - - def mark_taken(self) -> None: - """Mark the node as successfully completed.""" - self.state = NodeState.TAKEN - self.error = None - - def mark_failed(self, error: str) -> None: - """Mark the node as failed with an error.""" - self.error = error - - def mark_skipped(self) -> None: - """Mark the node as skipped.""" - self.state = NodeState.SKIPPED - - def increment_retry(self) -> None: - """Increment the retry count for this node.""" - self.retry_count += 1 diff --git a/api/graphon/graph_engine/entities/commands.py b/api/graphon/graph_engine/entities/commands.py deleted file mode 100644 index 25ebc804b6..0000000000 --- a/api/graphon/graph_engine/entities/commands.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -GraphEngine command entities for external control. - -This module defines command types that can be sent to a running GraphEngine -instance to control its execution flow. 
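Because the commands defined below are plain Pydantic models, they survive the JSON transport `RedisChannel` uses; a sketch of that round trip:

```python
import json

cmd = AbortCommand(reason="user requested abort")
wire = json.dumps(cmd.model_dump())  # what RedisChannel.send_command stores
restored = AbortCommand.model_validate(json.loads(wire))
assert restored.command_type == CommandType.ABORT
```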
-""" - -from collections.abc import Sequence -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.variables.variables import Variable - - -class CommandType(StrEnum): - """Types of commands that can be sent to GraphEngine.""" - - ABORT = auto() - PAUSE = auto() - UPDATE_VARIABLES = auto() - - -class GraphEngineCommand(BaseModel): - """Base class for all GraphEngine commands.""" - - command_type: CommandType = Field(..., description="Type of command") - payload: dict[str, Any] | None = Field(default=None, description="Optional command payload") - - -class AbortCommand(GraphEngineCommand): - """Command to abort a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command") - reason: str | None = Field(default=None, description="Optional reason for abort") - - -class PauseCommand(GraphEngineCommand): - """Command to pause a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") - reason: str = Field(default="unknown reason", description="reason for pause") - - -class VariableUpdate(BaseModel): - """Represents a single variable update instruction.""" - - value: Variable = Field(description="New variable value") - - -class UpdateVariablesCommand(GraphEngineCommand): - """Command to update a group of variables in the variable pool.""" - - command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") - updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/graphon/graph_engine/error_handler.py b/api/graphon/graph_engine/error_handler.py deleted file mode 100644 index 43ce8bb502..0000000000 --- a/api/graphon/graph_engine/error_handler.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Main error handler that coordinates error strategies. -""" - -import logging -import time -from typing import TYPE_CHECKING, final - -from graphon.enums import ( - ErrorStrategy as ErrorStrategyEnum, -) -from graphon.enums import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph import Graph -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunRetryEvent, -) -from graphon.node_events import NodeRunResult - -if TYPE_CHECKING: - from .domain import GraphExecution - -logger = logging.getLogger(__name__) - - -@final -class ErrorHandler: - """ - Coordinates error handling strategies for node failures. - - This acts as a facade for the various error strategies, - selecting and applying the appropriate strategy based on - node configuration. - """ - - def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: - """ - Initialize the error handler. - - Args: - graph: The workflow graph - graph_execution: The graph execution state - """ - self._graph = graph - self._graph_execution = graph_execution - - def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: - """ - Handle a node failure event. - - Selects and applies the appropriate error strategy based on - the node's configuration. 
- - Args: - event: The node failure event - - Returns: - Optional new event to process, or None to abort - """ - node = self._graph.nodes[event.node_id] - # Get retry count from NodeExecution - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - retry_count = node_execution.retry_count - - # First check if retry is configured and not exhausted - if node.retry and retry_count < node.retry_config.max_retries: - result = self._handle_retry(event, retry_count) - if result: - # Retry count will be incremented when NodeRunRetryEvent is handled - return result - - # Apply configured error strategy - strategy = node.error_strategy - - match strategy: - case None: - return self._handle_abort(event) - case ErrorStrategyEnum.FAIL_BRANCH: - return self._handle_fail_branch(event) - case ErrorStrategyEnum.DEFAULT_VALUE: - return self._handle_default_value(event) - - def _handle_abort(self, event: NodeRunFailedEvent): - """ - Handle error by aborting execution. - - This is the default strategy when no other strategy is specified. - It stops the entire graph execution when a node fails. - - Args: - event: The failure event - - Returns: - None - signals abortion - """ - logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) - # Return None to signal that execution should stop - - def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): - """ - Handle error by retrying the node. - - This strategy re-attempts node execution up to a configured - maximum number of retries with configurable intervals. - - Args: - event: The failure event - retry_count: Current retry attempt count - - Returns: - NodeRunRetryEvent if retry should occur, None otherwise - """ - node = self._graph.nodes[event.node_id] - - # Check if we've exceeded max retries - if not node.retry or retry_count >= node.retry_config.max_retries: - return None - - # Wait for retry interval - time.sleep(node.retry_config.retry_interval_seconds) - - # Create retry event - return NodeRunRetryEvent( - id=event.id, - node_title=node.title, - node_id=event.node_id, - node_type=event.node_type, - node_run_result=event.node_run_result, - start_at=event.start_at, - error=event.error, - retry_index=retry_count + 1, - ) - - def _handle_fail_branch(self, event: NodeRunFailedEvent): - """ - Handle error by taking the fail branch. - - This strategy converts failures to exceptions and routes execution - through a designated fail-branch edge. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent to continue via fail branch - """ - outputs = { - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - edge_source_handle="fail-branch", - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, - }, - ), - error=event.error, - ) - - def _handle_default_value(self, event: NodeRunFailedEvent): - """ - Handle error by using default values. - - This strategy allows nodes to fail gracefully by providing - predefined default output values. 
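Taken together, the precedence is: retry while attempts remain, then the node's configured strategy, with no strategy meaning abort. A condensed sketch of that ordering (same enum import as this module; the `node` config object is hypothetical):

```python
from graphon.enums import ErrorStrategy as ErrorStrategyEnum

def resolve_failure_action(node, retry_count: int) -> str:
    """Mirror handle_node_failure's decision order, reduced to a label."""
    # 1) Retry takes priority while the node still has attempts left.
    if node.retry and retry_count < node.retry_config.max_retries:
        return "retry"
    # 2) Otherwise fall back to the configured strategy.
    if node.error_strategy is None:
        return "abort"            # default: stop the whole graph run
    if node.error_strategy == ErrorStrategyEnum.FAIL_BRANCH:
        return "fail-branch"      # route through the designated failure edge
    return "default-value"        # emit predefined outputs and continue
```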
- - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent with default values - """ - node = self._graph.nodes[event.node_id] - - outputs = { - **node.default_value_dict, - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, - }, - ), - error=event.error, - ) diff --git a/api/graphon/graph_engine/event_management/__init__.py b/api/graphon/graph_engine/event_management/__init__.py deleted file mode 100644 index f6c3c0f753..0000000000 --- a/api/graphon/graph_engine/event_management/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Event management subsystem for graph engine. - -This package handles event routing, collection, and emission for -workflow graph execution events. -""" - -from .event_handlers import EventHandler -from .event_manager import EventManager - -__all__ = [ - "EventHandler", - "EventManager", -] diff --git a/api/graphon/graph_engine/event_management/event_handlers.py b/api/graphon/graph_engine/event_management/event_handlers.py deleted file mode 100644 index 184148280d..0000000000 --- a/api/graphon/graph_engine/event_management/event_handlers.py +++ /dev/null @@ -1,367 +0,0 @@ -""" -Event handler implementations for different event types. -""" - -import logging -from collections.abc import Mapping -from functools import singledispatchmethod -from typing import TYPE_CHECKING, final - -from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState -from graphon.graph import Graph -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState - -from ..domain.graph_execution import GraphExecution -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from ..error_handler import ErrorHandler - from ..graph_state_manager import GraphStateManager - from ..graph_traversal import EdgeProcessor - from .event_manager import EventManager - -logger = logging.getLogger(__name__) - - -@final -class EventHandler: - """ - Registry of event handlers for different event types. - - This centralizes the business logic for handling specific events, - keeping it separate from the routing and collection infrastructure. 
- """ - - def __init__( - self, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - graph_execution: GraphExecution, - response_coordinator: ResponseStreamCoordinator, - event_collector: "EventManager", - edge_processor: "EdgeProcessor", - state_manager: "GraphStateManager", - error_handler: "ErrorHandler", - ) -> None: - """ - Initialize the event handler registry. - - Args: - graph: The workflow graph - graph_runtime_state: Runtime state with variable pool - graph_execution: Graph execution aggregate - response_coordinator: Response stream coordinator - event_collector: Event manager for collecting events - edge_processor: Edge processor for edge traversal - state_manager: Unified state manager - error_handler: Error handler - """ - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - self._event_collector = event_collector - self._edge_processor = edge_processor - self._state_manager = state_manager - self._error_handler = error_handler - - def dispatch(self, event: GraphNodeEventBase) -> None: - """ - Handle any node event by dispatching to the appropriate handler. - - Args: - event: The event to handle - """ - if isinstance(event, NodeRunVariableUpdatedEvent): - self._dispatch(event) - return - - # Events in loops or iterations are always collected - if event.in_loop_id or event.in_iteration_id: - self._event_collector.collect(event) - return - return self._dispatch(event) - - @singledispatchmethod - def _dispatch(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - logger.warning("Unhandled event type: %s", type(event).__name__) - - @_dispatch.register(NodeRunIterationStartedEvent) - @_dispatch.register(NodeRunIterationNextEvent) - @_dispatch.register(NodeRunIterationSucceededEvent) - @_dispatch.register(NodeRunIterationFailedEvent) - @_dispatch.register(NodeRunLoopStartedEvent) - @_dispatch.register(NodeRunLoopNextEvent) - @_dispatch.register(NodeRunLoopSucceededEvent) - @_dispatch.register(NodeRunLoopFailedEvent) - @_dispatch.register(NodeRunAgentLogEvent) - @_dispatch.register(NodeRunRetrieverResourceEvent) - def _(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStartedEvent) -> None: - """ - Handle node started event. - - Args: - event: The node started event - """ - # Track execution in domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - is_initial_attempt = node_execution.retry_count == 0 - node_execution.mark_started(event.id) - self._graph_runtime_state.increment_node_run_steps() - - # Track in response coordinator for stream ordering - self._response_coordinator.track_node_execution(event.node_id, event.id) - - # Collect the event only for the first attempt; retries remain silent - if is_initial_attempt: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStreamChunkEvent) -> None: - """ - Handle stream chunk event with full processing. 
- - Args: - event: The stream chunk event - """ - # Process with response coordinator - streaming_events = list(self._response_coordinator.intercept_event(event)) - - # Collect all events - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - @_dispatch.register - def _(self, event: NodeRunVariableUpdatedEvent) -> None: - """ - Apply a node-requested variable mutation before downstream observers run. - - The event is collected like other node events so parent/container engines can - forward the updated payload to outer layers, including persistence listeners. - """ - self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable) - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunSucceededEvent) -> None: - """ - Handle node success by coordinating subsystems. - - This method coordinates between different subsystems to process - node completion, handle edges, and trigger downstream execution. - - Args: - event: The node succeeded event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Store outputs in variable pool - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - # Forward to response coordinator and emit streaming events - streaming_events = self._response_coordinator.intercept_event(event) - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - # Process edges and get ready nodes - node = self._graph.nodes[event.node_id] - if node.execution_type == NodeExecutionType.BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - - # Collect streaming events from edge processing - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - # Enqueue ready nodes - if self._graph_execution.is_paused: - for node_id in ready_nodes: - self._graph_runtime_state.register_deferred_node(node_id) - else: - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update execution tracking - self._state_manager.finish_execution(event.node_id) - - # Handle response node outputs - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - # Collect the event - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunPauseRequestedEvent) -> None: - """Handle pause requests emitted by nodes.""" - - pause_reason = event.reason - self._graph_execution.pause(pause_reason) - self._state_manager.finish_execution(event.node_id) - if event.node_id in self._graph.nodes: - self._graph.nodes[event.node_id].state = NodeState.UNKNOWN - self._graph_runtime_state.register_paused_node(event.node_id) - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunFailedEvent) -> None: - """ - Handle node failure using error handler. 
- - Args: - event: The node failed event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_failed(event.error) - self._graph_execution.record_node_failure() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - result = self._error_handler.handle_node_failure(event) - - if result: - # Process the resulting event (retry, exception, etc.) - self.dispatch(result) - else: - # Abort execution - self._graph_execution.fail(RuntimeError(event.error)) - self._event_collector.collect(event) - self._state_manager.finish_execution(event.node_id) - - @_dispatch.register - def _(self, event: NodeRunExceptionEvent) -> None: - """ - Handle node exception event (fail-branch strategy). - - Args: - event: The node exception event - """ - # Node continues via fail-branch/default-value, treat as completion - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Persist outputs produced by the exception strategy (e.g. default values) - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - node = self._graph.nodes[event.node_id] - - if node.error_strategy == ErrorStrategy.DEFAULT_VALUE: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - elif node.error_strategy == ErrorStrategy.FAIL_BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}") - - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update response outputs if applicable - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - self._state_manager.finish_execution(event.node_id) - - # Collect the exception event for observers - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunRetryEvent) -> None: - """ - Handle node retry event. - - Args: - event: The node retry event - """ - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.increment_retry() - - # Finish the previous attempt before re-queuing the node - self._state_manager.finish_execution(event.node_id) - - # Emit retry event for observers - self._event_collector.collect(event) - - # Re-queue node for execution - self._state_manager.enqueue_node(event.node_id) - self._state_manager.start_execution(event.node_id) - - def _accumulate_node_usage(self, usage: LLMUsage) -> None: - """Accumulate token usage into the shared runtime state.""" - if usage.total_tokens <= 0: - return - - self._graph_runtime_state.add_tokens(usage.total_tokens) - - current_usage = self._graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self._graph_runtime_state.llm_usage = usage - else: - self._graph_runtime_state.llm_usage = current_usage.plus(usage) - - def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None: - """ - Store node outputs in the variable pool. 
- - Args: - node_id: The ID of the node whose outputs are being stored - outputs: Mapping of output variable names to values - """ - for variable_name, variable_value in outputs.items(): - self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value) - - def _update_response_outputs(self, outputs: Mapping[str, object]) -> None: - """Update response outputs for response nodes.""" - # TODO: Design a mechanism for nodes to notify the engine about how to update outputs - # in runtime state, rather than allowing nodes to directly access runtime state. - for key, value in outputs.items(): - if key == "answer": - existing = self._graph_runtime_state.get_output("answer", "") - if existing: - self._graph_runtime_state.set_output("answer", f"{existing}{value}") - else: - self._graph_runtime_state.set_output("answer", value) - else: - self._graph_runtime_state.set_output(key, value) diff --git a/api/graphon/graph_engine/event_management/event_manager.py b/api/graphon/graph_engine/event_management/event_manager.py deleted file mode 100644 index 5b2fb365e9..0000000000 --- a/api/graphon/graph_engine/event_management/event_manager.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -Unified event manager for collecting and emitting events. -""" - -import logging -import threading -import time -from collections.abc import Generator -from contextlib import contextmanager -from typing import final - -from graphon.graph_events import GraphEngineEvent - -from ..layers.base import GraphEngineLayer - -_logger = logging.getLogger(__name__) - - -@final -class ReadWriteLock: - """ - A read-write lock implementation that allows multiple concurrent readers - but only one writer at a time. - """ - - def __init__(self) -> None: - self._read_ready = threading.Condition(threading.RLock()) - self._readers = 0 - - def acquire_read(self) -> None: - """Acquire a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers += 1 - finally: - self._read_ready.release() - - def release_read(self) -> None: - """Release a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers -= 1 - if self._readers == 0: - self._read_ready.notify_all() - finally: - self._read_ready.release() - - def acquire_write(self) -> None: - """Acquire a write lock.""" - _ = self._read_ready.acquire() - while self._readers > 0: - _ = self._read_ready.wait() - - def release_write(self) -> None: - """Release a write lock.""" - self._read_ready.release() - - @contextmanager - def read_lock(self): - """Return a context manager for read locking.""" - self.acquire_read() - try: - yield - finally: - self.release_read() - - @contextmanager - def write_lock(self): - """Return a context manager for write locking.""" - self.acquire_write() - try: - yield - finally: - self.release_write() - - -@final -class EventManager: - """ - Unified event manager that collects, buffers, and emits events. - - This class combines event collection with event emission, providing - thread-safe event management with support for notifying layers and - streaming events to external consumers. - """ - - def __init__(self) -> None: - """Initialize the event manager.""" - self._events: list[GraphEngineEvent] = [] - self._lock = ReadWriteLock() - self._layers: list[GraphEngineLayer] = [] - self._execution_complete = threading.Event() - - def set_layers(self, layers: list[GraphEngineLayer]) -> None: - """ - Set the layers to notify on event collection.
- - Args: - layers: List of layers to notify - """ - self._layers = layers - - def notify_layers(self, event: GraphEngineEvent) -> None: - """Notify registered layers about an event without buffering it.""" - self._notify_layers(event) - - def collect(self, event: GraphEngineEvent) -> None: - """ - Thread-safe method to collect an event. - - Args: - event: The event to collect - """ - with self._lock.write_lock(): - self._events.append(event) - - # NOTE: `_notify_layers` is intentionally called outside the critical section - # to minimize lock contention and avoid blocking other readers or writers. - # - # The public `notify_layers` method also does not use a write lock, - # so protecting `_notify_layers` with a lock here is unnecessary. - self._notify_layers(event) - - def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]: - """ - Get new events starting from a specific index. - - Args: - start_index: The index to start from - - Returns: - List of new events - """ - with self._lock.read_lock(): - return list(self._events[start_index:]) - - def _event_count(self) -> int: - """ - Get the current count of collected events. - - Returns: - Number of collected events - """ - with self._lock.read_lock(): - return len(self._events) - - def mark_complete(self) -> None: - """Mark execution as complete to stop the event emission generator.""" - self._execution_complete.set() - - def emit_events(self) -> Generator[GraphEngineEvent, None, None]: - """ - Generator that yields events as they're collected. - - Yields: - GraphEngineEvent instances as they're processed - """ - yielded_count = 0 - - while not self._execution_complete.is_set() or yielded_count < self._event_count(): - # Get new events since last yield - new_events = self._get_new_events(yielded_count) - - # Yield any new events - for event in new_events: - yield event - yielded_count += 1 - - # Small sleep to avoid busy waiting - if not self._execution_complete.is_set() and not new_events: - time.sleep(0.001) - - def _notify_layers(self, event: GraphEngineEvent) -> None: - """ - Notify all layers of an event. - - Layer exceptions are caught and logged to prevent disrupting collection. - - Args: - event: The event to send to layers - """ - for layer in self._layers: - try: - layer.on_event(event) - except Exception: - _logger.exception("Error in layer on_event, layer_type=%s", type(layer)) diff --git a/api/graphon/graph_engine/graph_engine.py b/api/graphon/graph_engine/graph_engine.py deleted file mode 100644 index 32e0e60502..0000000000 --- a/api/graphon/graph_engine/graph_engine.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution. - -This engine uses a modular architecture with separated packages following -Domain-Driven Design principles for improved maintainability and testability. 
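Before the engine itself, a short sketch of the producer/consumer contract the `EventManager` above exposes (import path taken from this diff; the event list is a stand-in):

```python
import threading

from graphon.graph_engine.event_management import EventManager
from graphon.graph_events import GraphEngineEvent

manager = EventManager()
pending: list[GraphEngineEvent] = []  # stand-in for events produced by workers

def produce() -> None:
    for event in pending:
        manager.collect(event)  # write-locked append, then layer fan-out
    manager.mark_complete()     # lets emit_events() finish draining

threading.Thread(target=produce).start()

for event in manager.emit_events():  # streams events in collection order
    print(type(event).__name__)
```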
-""" - -from __future__ import annotations - -import logging -import queue -from collections.abc import Generator -from typing import TYPE_CHECKING, cast, final - -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.enums import NodeExecutionType -from graphon.graph import Graph -from graphon.graph_events import ( - GraphEngineEvent, - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from graphon.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol - -if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from graphon.runtime.graph_runtime_state import GraphProtocol - -from .command_processing import ( - AbortCommandHandler, - CommandProcessor, - PauseCommandHandler, - UpdateVariablesCommandHandler, -) -from .config import GraphEngineConfig -from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand -from .error_handler import ErrorHandler -from .event_management import EventHandler, EventManager -from .graph_state_manager import GraphStateManager -from .graph_traversal import EdgeProcessor, SkipPropagator -from .layers.base import GraphEngineLayer -from .orchestration import Dispatcher, ExecutionCoordinator -from .protocols.command_channel import CommandChannel -from .worker_management import WorkerPool - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.graph_engine.domain.graph_execution import GraphExecution - from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator - -logger = logging.getLogger(__name__) - - -_DEFAULT_CONFIG = GraphEngineConfig() - - -@final -class GraphEngine: - """ - Queue-based graph execution engine. - - Uses a modular architecture that delegates responsibilities to specialized - subsystems, following Domain-Driven Design and SOLID principles. 
- """ - - def __init__( - self, - workflow_id: str, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - command_channel: CommandChannel, - config: GraphEngineConfig = _DEFAULT_CONFIG, - child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, - ) -> None: - """Initialize the graph engine with all subsystems and dependencies.""" - - # Bind runtime state to current workflow context - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) - self._command_channel = command_channel - self._config = config - self._layers: list[GraphEngineLayer] = [] - self._child_engine_builder = child_engine_builder - if child_engine_builder is not None: - self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) - - # Graph execution tracks the overall execution state - self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) - self._graph_execution.workflow_id = workflow_id - - # === Execution Queues === - self._ready_queue = self._graph_runtime_state.ready_queue - - # Queue for events generated during execution - self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() - - # === State Management === - # Unified state manager handles all node state transitions and queue operations - self._state_manager = GraphStateManager(self._graph, self._ready_queue) - - # === Response Coordination === - # Coordinates response streaming from response nodes - self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator) - - # === Event Management === - # Event manager handles both collection and emission of events - self._event_manager = EventManager() - - # === Error Handling === - # Centralized error handler for graph execution errors - self._error_handler = ErrorHandler(self._graph, self._graph_execution) - - # === Graph Traversal Components === - # Propagates skip status through the graph when conditions aren't met - self._skip_propagator = SkipPropagator( - graph=self._graph, - state_manager=self._state_manager, - ) - - # Processes edges to determine next nodes after execution - # Also handles conditional branching and route selection - self._edge_processor = EdgeProcessor( - graph=self._graph, - state_manager=self._state_manager, - response_coordinator=self._response_coordinator, - skip_propagator=self._skip_propagator, - ) - - # === Command Processing === - # Processes external commands (e.g., abort requests) - self._command_processor = CommandProcessor( - command_channel=self._command_channel, - graph_execution=self._graph_execution, - ) - - # Register command handlers - abort_handler = AbortCommandHandler() - self._command_processor.register_handler(AbortCommand, abort_handler) - - pause_handler = PauseCommandHandler() - self._command_processor.register_handler(PauseCommand, pause_handler) - - update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) - self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) - - # === Worker Pool Setup === - # Create worker pool for parallel node execution - self._worker_pool = WorkerPool( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - execution_context=self._graph_runtime_state.execution_context, - config=self._config, - ) - - # === Orchestration === - # Coordinates the overall execution lifecycle - self._execution_coordinator = 
ExecutionCoordinator( - graph_execution=self._graph_execution, - state_manager=self._state_manager, - command_processor=self._command_processor, - worker_pool=self._worker_pool, - ) - - # === Event Handler Registry === - # Central registry for handling all node execution events - self._event_handler_registry = EventHandler( - graph=self._graph, - graph_runtime_state=self._graph_runtime_state, - graph_execution=self._graph_execution, - response_coordinator=self._response_coordinator, - event_collector=self._event_manager, - edge_processor=self._edge_processor, - state_manager=self._state_manager, - error_handler=self._error_handler, - ) - - # Dispatches events and manages execution flow - self._dispatcher = Dispatcher( - event_queue=self._event_queue, - event_handler=self._event_handler_registry, - execution_coordinator=self._execution_coordinator, - event_emitter=self._event_manager, - ) - - # === Validation === - # Ensure all nodes share the same GraphRuntimeState instance - self._validate_graph_state_consistency() - - def _validate_graph_state_consistency(self) -> None: - """Validate that all nodes share the same GraphRuntimeState.""" - expected_state_id = id(self._graph_runtime_state) - for node in self._graph.nodes.values(): - if id(node.graph_runtime_state) != expected_state_id: - raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") - - def _bind_layer_context( - self, - layer: GraphEngineLayer, - ) -> None: - layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) - - def layer(self, layer: GraphEngineLayer) -> GraphEngine: - """Add a layer for extending functionality.""" - self._layers.append(layer) - self._bind_layer_context(layer) - return self - - def request_abort(self, reason: str | None = None) -> None: - """Queue an abort command for this engine.""" - self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort")) - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> GraphEngine: - return self._graph_runtime_state.create_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - variable_pool=variable_pool, - ) - - def run(self) -> Generator[GraphEngineEvent, None, None]: - """ - Execute the graph using the modular architecture. - - Returns: - Generator yielding GraphEngineEvent instances - """ - try: - # Initialize layers - self._initialize_layers() - - is_resume = self._graph_execution.started - if not is_resume: - self._graph_execution.start() - else: - self._graph_execution.paused = False - self._graph_execution.pause_reasons = [] - - start_event = GraphRunStartedEvent( - reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL, - ) - self._event_manager.notify_layers(start_event) - yield start_event - - # Start subsystems - self._start_execution(resume=is_resume) - - # Yield events as they occur - yield from self._event_manager.emit_events() - - # Handle completion - if self._graph_execution.is_paused: - pause_reasons = self._graph_execution.pause_reasons - assert pause_reasons, "pause_reasons should not be empty when execution is paused." 
- # Ensure we have a valid PauseReason for the event - paused_event = GraphRunPausedEvent( - reasons=pause_reasons, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(paused_event) - yield paused_event - elif self._graph_execution.aborted: - abort_reason = "Workflow execution aborted by user command" - if self._graph_execution.error: - abort_reason = str(self._graph_execution.error) - aborted_event = GraphRunAbortedEvent( - reason=abort_reason, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(aborted_event) - yield aborted_event - elif self._graph_execution.has_error: - if self._graph_execution.error: - raise self._graph_execution.error - else: - outputs = self._graph_runtime_state.outputs - exceptions_count = self._graph_execution.exceptions_count - if exceptions_count > 0: - partial_event = GraphRunPartialSucceededEvent( - exceptions_count=exceptions_count, - outputs=outputs, - ) - self._event_manager.notify_layers(partial_event) - yield partial_event - else: - succeeded_event = GraphRunSucceededEvent( - outputs=outputs, - ) - self._event_manager.notify_layers(succeeded_event) - yield succeeded_event - - except Exception as e: - failed_event = GraphRunFailedEvent( - error=str(e), - exceptions_count=self._graph_execution.exceptions_count, - ) - self._event_manager.notify_layers(failed_event) - yield failed_event - raise - - finally: - self._stop_execution() - - def _initialize_layers(self) -> None: - """Initialize layers with context.""" - self._event_manager.set_layers(self._layers) - for layer in self._layers: - try: - layer.on_graph_start() - except Exception: - logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__) - - def _start_execution(self, *, resume: bool = False) -> None: - """Start execution subsystems.""" - paused_nodes: list[str] = [] - deferred_nodes: list[str] = [] - if resume: - paused_nodes = self._graph_runtime_state.consume_paused_nodes() - deferred_nodes = self._graph_runtime_state.consume_deferred_nodes() - - # Start worker pool (it calculates initial workers internally) - self._worker_pool.start() - - # Register response nodes - for node in self._graph.nodes.values(): - if node.execution_type == NodeExecutionType.RESPONSE: - self._response_coordinator.register(node.id) - - if not resume: - # Enqueue root node - root_node = self._graph.root_node - self._state_manager.enqueue_node(root_node.id) - self._state_manager.start_execution(root_node.id) - else: - seen_nodes: set[str] = set() - for node_id in paused_nodes + deferred_nodes: - if node_id in seen_nodes: - continue - seen_nodes.add(node_id) - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Start dispatcher - self._dispatcher.start() - - def _stop_execution(self) -> None: - """Stop execution subsystems.""" - self._dispatcher.stop() - self._worker_pool.stop() - # Don't mark complete here as the dispatcher already does it - - # Notify layers - for layer in self._layers: - try: - layer.on_graph_end(self._graph_execution.error) - except Exception: - logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__) - - # Public property accessors for attributes that need external access - @property - def graph_runtime_state(self) -> GraphRuntimeState: - """Get the graph runtime state.""" - return self._graph_runtime_state diff --git a/api/graphon/graph_engine/graph_state_manager.py b/api/graphon/graph_engine/graph_state_manager.py deleted file mode 100644 index 
ade8e403a8..0000000000 --- a/api/graphon/graph_engine/graph_state_manager.py +++ /dev/null @@ -1,290 +0,0 @@ -""" -Graph state manager that combines node, edge, and execution tracking. -""" - -import threading -from collections.abc import Sequence -from typing import TypedDict, final - -from graphon.enums import NodeState -from graphon.graph import Edge, Graph - -from .ready_queue import ReadyQueue - - -class EdgeStateAnalysis(TypedDict): - """Analysis result for edge states.""" - - has_unknown: bool - has_taken: bool - all_skipped: bool - - -@final -class GraphStateManager: - def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None: - """ - Initialize the state manager. - - Args: - graph: The workflow graph - ready_queue: Queue for nodes ready to execute - """ - self._graph = graph - self._ready_queue = ready_queue - self._lock = threading.RLock() - - # Execution tracking state - self._executing_nodes: set[str] = set() - - # ============= Node State Operations ============= - - def enqueue_node(self, node_id: str) -> None: - """ - Mark a node as TAKEN and add it to the ready queue. - - This combines the state transition and enqueueing operations - that always occur together when preparing a node for execution. - - Args: - node_id: The ID of the node to enqueue - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.TAKEN - self._ready_queue.put(node_id) - - def mark_node_skipped(self, node_id: str) -> None: - """ - Mark a node as SKIPPED. - - Args: - node_id: The ID of the node to skip - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.SKIPPED - - def is_node_ready(self, node_id: str) -> bool: - """ - Check if a node is ready to be executed. - - A node is ready when all its incoming edges from taken branches - have been satisfied. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is ready for execution - """ - with self._lock: - # Get all incoming edges to this node - incoming_edges = self._graph.get_incoming_edges(node_id) - - # If no incoming edges, node is always ready - if not incoming_edges: - return True - - # If any edge is UNKNOWN, node is not ready - if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges): - return False - - # Node is ready if at least one edge is TAKEN - return any(edge.state == NodeState.TAKEN for edge in incoming_edges) - - def get_node_state(self, node_id: str) -> NodeState: - """ - Get the current state of a node. - - Args: - node_id: The ID of the node - - Returns: - The current node state - """ - with self._lock: - return self._graph.nodes[node_id].state - - # ============= Edge State Operations ============= - - def mark_edge_taken(self, edge_id: str) -> None: - """ - Mark an edge as TAKEN. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.TAKEN - - def mark_edge_skipped(self, edge_id: str) -> None: - """ - Mark an edge as SKIPPED. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.SKIPPED - - def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis: - """ - Analyze the states of edges and return summary flags. 
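The readiness rule in `is_node_ready` above reduces to a small predicate over incoming edge states, sketched here standalone:

```python
from graphon.enums import NodeState

def node_ready(incoming_edge_states: list[NodeState]) -> bool:
    """Condensed form of is_node_ready's rule over plain edge states."""
    if not incoming_edge_states:                    # no dependencies: always ready
        return True
    if NodeState.UNKNOWN in incoming_edge_states:   # an upstream path is undecided
        return False
    return NodeState.TAKEN in incoming_edge_states  # needs at least one live path
```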
- - Args: - edges: List of edges to analyze - - Returns: - Analysis result with state flags - """ - with self._lock: - states = {edge.state for edge in edges} - - return EdgeStateAnalysis( - has_unknown=NodeState.UNKNOWN in states, - has_taken=NodeState.TAKEN in states, - all_skipped=states == {NodeState.SKIPPED} if states else True, - ) - - def get_edge_state(self, edge_id: str) -> NodeState: - """ - Get the current state of an edge. - - Args: - edge_id: The ID of the edge - - Returns: - The current edge state - """ - with self._lock: - return self._graph.edges[edge_id].state - - def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: - """ - Categorize branch edges into selected and unselected. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - A tuple of (selected_edges, unselected_edges) - """ - with self._lock: - outgoing_edges = self._graph.get_outgoing_edges(node_id) - selected_edges: list[Edge] = [] - unselected_edges: list[Edge] = [] - - for edge in outgoing_edges: - if edge.source_handle == selected_handle: - selected_edges.append(edge) - else: - unselected_edges.append(edge) - - return selected_edges, unselected_edges - - # ============= Execution Tracking Operations ============= - - def start_execution(self, node_id: str) -> None: - """ - Mark a node as executing. - - Args: - node_id: The ID of the node starting execution - """ - with self._lock: - self._executing_nodes.add(node_id) - - def finish_execution(self, node_id: str) -> None: - """ - Mark a node as no longer executing. - - Args: - node_id: The ID of the node finishing execution - """ - with self._lock: - self._executing_nodes.discard(node_id) - - def is_executing(self, node_id: str) -> bool: - """ - Check if a node is currently executing. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is executing - """ - with self._lock: - return node_id in self._executing_nodes - - def get_executing_count(self) -> int: - """ - Get the count of currently executing nodes. - - Returns: - Number of executing nodes - """ - # This count is a best-effort snapshot and can change concurrently. - # Only use it for pause-drain checks where scheduling is already frozen. - with self._lock: - return len(self._executing_nodes) - - def get_executing_nodes(self) -> set[str]: - """ - Get a copy of the set of executing node IDs. - - Returns: - Set of node IDs currently executing - """ - with self._lock: - return self._executing_nodes.copy() - - def clear_executing(self) -> None: - """Clear all executing nodes.""" - with self._lock: - self._executing_nodes.clear() - - # ============= Composite Operations ============= - - def is_execution_complete(self) -> bool: - """ - Check if graph execution is complete. - - Execution is complete when: - - Ready queue is empty - - No nodes are executing - - Returns: - True if execution is complete - """ - with self._lock: - return self._ready_queue.empty() and len(self._executing_nodes) == 0 - - def get_queue_depth(self) -> int: - """ - Get the current depth of the ready queue. - - Returns: - Number of nodes in the ready queue - """ - return self._ready_queue.qsize() - - def get_execution_stats(self) -> dict[str, int]: - """ - Get execution statistics. 
- - Returns: - Dictionary with execution statistics - """ - with self._lock: - taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN) - skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED) - unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN) - - return { - "queue_depth": self._ready_queue.qsize(), - "executing": len(self._executing_nodes), - "taken_nodes": taken_nodes, - "skipped_nodes": skipped_nodes, - "unknown_nodes": unknown_nodes, - } diff --git a/api/graphon/graph_engine/graph_traversal/__init__.py b/api/graphon/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index d629140d06..0000000000 --- a/api/graphon/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Graph traversal subsystem for graph engine. - -This package handles graph navigation, edge processing, -and skip propagation logic. -""" - -from .edge_processor import EdgeProcessor -from .skip_propagator import SkipPropagator - -__all__ = [ - "EdgeProcessor", - "SkipPropagator", -] diff --git a/api/graphon/graph_engine/graph_traversal/edge_processor.py b/api/graphon/graph_engine/graph_traversal/edge_processor.py deleted file mode 100644 index e51eee8a69..0000000000 --- a/api/graphon/graph_engine/graph_traversal/edge_processor.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Edge processing logic for graph traversal. -""" - -from collections.abc import Sequence -from typing import TYPE_CHECKING, final - -from graphon.enums import NodeExecutionType -from graphon.graph import Edge, Graph -from graphon.graph_events import NodeRunStreamChunkEvent - -from ..graph_state_manager import GraphStateManager -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from .skip_propagator import SkipPropagator - - -@final -class EdgeProcessor: - """ - Processes edges during graph execution. - - This handles marking edges as taken or skipped, notifying - the response coordinator, triggering downstream node execution, - and managing branch node logic. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - response_coordinator: ResponseStreamCoordinator, - skip_propagator: "SkipPropagator", - ) -> None: - """ - Initialize the edge processor. - - Args: - graph: The workflow graph - state_manager: Unified state manager - response_coordinator: Response stream coordinator - skip_propagator: Propagator for skip states - """ - self._graph = graph - self._state_manager = state_manager - self._response_coordinator = response_coordinator - self._skip_propagator = skip_propagator - - def process_node_success( - self, node_id: str, selected_handle: str | None = None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges after a node succeeds. - - Args: - node_id: The ID of the succeeded node - selected_handle: For branch nodes, the selected edge handle - - Returns: - Tuple of (list of downstream node IDs that are now ready, list of streaming events) - """ - node = self._graph.nodes[node_id] - - if node.execution_type == NodeExecutionType.BRANCH: - return self._process_branch_node_edges(node_id, selected_handle) - else: - return self._process_non_branch_node_edges(node_id) - - def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for non-branch nodes (mark all as TAKEN). 
- - Args: - node_id: The ID of the succeeded node - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - """ - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - outgoing_edges = self._graph.get_outgoing_edges(node_id) - - for edge in outgoing_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_branch_node_edges( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for branch nodes. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no edge was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} did not select any edge") - - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - - # Categorize edges - selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Process unselected edges first (mark as skipped) - for edge in unselected_edges: - self._process_skipped_edge(edge) - - # Process selected edges - for edge in selected_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Mark edge as taken and check downstream node. - - Args: - edge: The edge to process - - Returns: - Tuple of (list containing downstream node ID if it's ready, list of streaming events) - """ - # Mark edge as taken - self._state_manager.mark_edge_taken(edge.id) - - # Notify response coordinator and get streaming events - streaming_events = self._response_coordinator.on_edge_taken(edge.id) - - # Check if downstream node is ready - ready_nodes: list[str] = [] - if self._state_manager.is_node_ready(edge.head): - ready_nodes.append(edge.head) - - return ready_nodes, streaming_events - - def _process_skipped_edge(self, edge: Edge) -> None: - """ - Mark edge as skipped. - - Args: - edge: The edge to skip - """ - self._state_manager.mark_edge_skipped(edge.id) - - def handle_branch_completion( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Handle completion of a branch node. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected branch - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no branch was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} completed without selecting a branch") - - # Categorize edges into selected and unselected - _, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Skip all unselected paths - self._skip_propagator.skip_branch_paths(unselected_edges) - - # Process selected edges and get ready nodes and streaming events - return self.process_node_success(node_id, selected_handle) - - def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool: - """ - Validate that a branch selection is valid. 
- - Args: - node_id: The ID of the branch node - selected_handle: The handle to validate - - Returns: - True if the selection is valid - """ - outgoing_edges = self._graph.get_outgoing_edges(node_id) - valid_handles = {edge.source_handle for edge in outgoing_edges} - return selected_handle in valid_handles diff --git a/api/graphon/graph_engine/graph_traversal/skip_propagator.py b/api/graphon/graph_engine/graph_traversal/skip_propagator.py deleted file mode 100644 index bdb83b38ad..0000000000 --- a/api/graphon/graph_engine/graph_traversal/skip_propagator.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Skip state propagation through the graph. -""" - -from collections.abc import Sequence -from typing import final - -from graphon.graph import Edge, Graph - -from ..graph_state_manager import GraphStateManager - - -@final -class SkipPropagator: - """ - Propagates skip states through the graph. - - When a node is skipped, this ensures all downstream nodes - that depend solely on it are also skipped. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - ) -> None: - """ - Initialize the skip propagator. - - Args: - graph: The workflow graph - state_manager: Unified state manager - """ - self._graph = graph - self._state_manager = state_manager - - def propagate_skip_from_edge(self, edge_id: str) -> None: - """ - Recursively propagate skip state from a skipped edge. - - Rules: - - If a node has any UNKNOWN incoming edges, stop processing - - If all incoming edges are SKIPPED, skip the node and its edges - - If any incoming edge is TAKEN, the node may still execute - - Args: - edge_id: The ID of the skipped edge to start from - """ - downstream_node_id = self._graph.edges[edge_id].head - incoming_edges = self._graph.get_incoming_edges(downstream_node_id) - - # Analyze edge states - edge_states = self._state_manager.analyze_edge_states(incoming_edges) - - # Stop if there are unknown edges (not yet processed) - if edge_states["has_unknown"]: - return - - # If any edge is taken, node may still execute - if edge_states["has_taken"]: - # Enqueue node - self._state_manager.enqueue_node(downstream_node_id) - self._state_manager.start_execution(downstream_node_id) - return - - # All edges are skipped, propagate skip to this node - if edge_states["all_skipped"]: - self._propagate_skip_to_node(downstream_node_id) - - def _propagate_skip_to_node(self, node_id: str) -> None: - """ - Mark a node and all its outgoing edges as skipped. - - Args: - node_id: The ID of the node to skip - """ - # Mark node as skipped - self._state_manager.mark_node_skipped(node_id) - - # Mark all outgoing edges as skipped and propagate - outgoing_edges = self._graph.get_outgoing_edges(node_id) - for edge in outgoing_edges: - self._state_manager.mark_edge_skipped(edge.id) - # Recursively propagate skip - self.propagate_skip_from_edge(edge.id) - - def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None: - """ - Skip all paths from unselected branch edges. - - Args: - unselected_edges: List of edges not taken by the branch - """ - for edge in unselected_edges: - self._state_manager.mark_edge_skipped(edge.id) - self.propagate_skip_from_edge(edge.id) diff --git a/api/graphon/graph_engine/layers/README.md b/api/graphon/graph_engine/layers/README.md deleted file mode 100644 index b0f295037c..0000000000 --- a/api/graphon/graph_engine/layers/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# Layers - -Pluggable middleware for engine extensions. 
- - ## Components - - ### GraphEngineLayer (base) - - Abstract base class for layers. - - `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) - `on_graph_start()` - Execution start hook - `on_event()` - Process all events - `on_graph_end()` - Execution end hook - - ### DebugLoggingLayer - - Comprehensive execution logging. - - Configurable detail levels - Tracks execution statistics - Truncates long values - - ## Usage - - ```python -debug_layer = DebugLoggingLayer( - level="INFO", - include_outputs=True -) - -engine = GraphEngine( - workflow_id=workflow_id, - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=command_channel, -) -engine.layer(debug_layer) -for event in engine.run(): - ... -``` - - `engine.layer()` binds the read-only runtime state before execution, so -`graph_runtime_state` is always available inside layer hooks. `run()` returns a generator, so iterate it (as above) to actually drive execution. - - ## Custom Layers - - ```python -class MetricsLayer(GraphEngineLayer): - def __init__(self): - super().__init__() - self.metrics = {} - - def on_event(self, event): - if isinstance(event, NodeRunSucceededEvent): - self.metrics[event.node_id] = event.elapsed_time -``` - -Implement the remaining abstract hooks (`on_graph_start()`, `on_graph_end()`) to complete the layer. - - ## Configuration - - **DebugLoggingLayer Options:** - - `level` - Log level (INFO, DEBUG, ERROR) - `include_inputs/outputs` - Log data values - `max_value_length` - Truncate long values diff --git a/api/graphon/graph_engine/layers/__init__.py b/api/graphon/graph_engine/layers/__init__.py deleted file mode 100644 index 0a29a52993..0000000000 --- a/api/graphon/graph_engine/layers/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Layer system for GraphEngine extensibility. - -This module provides the layer infrastructure for extending GraphEngine functionality -with middleware-like components that can observe events and interact with execution. -""" - -from .base import GraphEngineLayer -from .debug_logging import DebugLoggingLayer -from .execution_limits import ExecutionLimitsLayer - -__all__ = [ - "DebugLoggingLayer", - "ExecutionLimitsLayer", - "GraphEngineLayer", -] diff --git a/api/graphon/graph_engine/layers/base.py b/api/graphon/graph_engine/layers/base.py deleted file mode 100644 index 605615d347..0000000000 --- a/api/graphon/graph_engine/layers/base.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -Base layer class for GraphEngine extensions. - -This module provides the abstract base class for implementing layers that can -intercept and respond to GraphEngine events. -""" - -from abc import ABC, abstractmethod - -from graphon.graph_engine.protocols.command_channel import CommandChannel -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.runtime import ReadOnlyGraphRuntimeState - - -class GraphEngineLayerNotInitializedError(Exception): - """Raised when a layer's runtime state is accessed before initialization.""" - - def __init__(self, layer_name: str | None = None) -> None: - name = layer_name or "GraphEngineLayer" - super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") - - -class GraphEngineLayer(ABC): - """ - Abstract base class for GraphEngine layers. - - Layers are middleware-like components that can: - - Observe all events emitted by the GraphEngine - - Access the graph runtime state - - Send commands to control execution - - Subclasses should override the constructor to accept configuration parameters, - then implement the three lifecycle methods. - """ - - def __init__(self) -> None: - """Initialize the layer.
Subclasses can override with custom parameters.""" - self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None - self.command_channel: CommandChannel | None = None - - @property - def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: - if self._graph_runtime_state is None: - raise GraphEngineLayerNotInitializedError(type(self).__name__) - return self._graph_runtime_state - - def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: - """ - Initialize the layer with engine dependencies. - - Called by GraphEngine to inject the read-only runtime state and command channel. - This is invoked when the layer is registered with a `GraphEngine` instance. - Implementations should be idempotent. - Args: - graph_runtime_state: Read-only view of the runtime state - command_channel: Channel for sending commands to the engine - """ - self._graph_runtime_state = graph_runtime_state - self.command_channel = command_channel - - @abstractmethod - def on_graph_start(self) -> None: - """ - Called when graph execution starts. - - This is called after the engine has been initialized but before any nodes - are executed. Layers can use this to set up resources or log start information. - """ - pass - - @abstractmethod - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - This method receives all events generated during graph execution, including: - - Graph lifecycle events (start, success, failure) - - Node execution events (start, success, failure, retry) - - Stream events for response nodes - - Container events (iteration, loop) - - Args: - event: The event emitted by the engine - """ - pass - - @abstractmethod - def on_graph_end(self, error: Exception | None) -> None: - """ - Called when graph execution ends. - - This is called after all nodes have been executed or when execution is - aborted. Layers can use this to clean up resources or log final state. - - Args: - error: The exception that caused execution to fail, or None if successful - """ - pass - - def on_node_run_start(self, node: Node) -> None: - """ - Called immediately before a node begins execution. - - Layers can override to inject behavior (e.g., start spans) prior to node execution. - The node's execution ID is available via `node._node_execution_id` and will be - consistent with all events emitted by this node execution. - - Args: - node: The node instance about to be executed - """ - return - - def on_node_run_end( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """ - Called after a node finishes execution. - - The node's execution ID is available via `node._node_execution_id` and matches - the `id` field in all events emitted by this node execution. - - Args: - node: The node instance that just finished execution - error: Exception instance if the node failed, otherwise None - result_event: The final result event from node execution (succeeded/failed/paused), if any - """ - return diff --git a/api/graphon/graph_engine/layers/debug_logging.py b/api/graphon/graph_engine/layers/debug_logging.py deleted file mode 100644 index e6585fb3b9..0000000000 --- a/api/graphon/graph_engine/layers/debug_logging.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Debug logging layer for GraphEngine. - -This module provides a layer that logs all events and state changes during -graph execution for debugging purposes. 
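All three lifecycle hooks are abstract on `GraphEngineLayer`, so the README's four-line `MetricsLayer` sketch cannot be instantiated as written. A fuller, still illustrative version against the base class above, reusing the `event.elapsed_time` field from the README sketch:

```python
import time

from typing_extensions import override

from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, NodeRunSucceededEvent


class MetricsLayer(GraphEngineLayer):
    """Collects per-node elapsed times plus a total wall-clock figure."""

    def __init__(self) -> None:
        super().__init__()
        self.metrics: dict[str, float] = {}
        self._started_at: float | None = None

    @override
    def on_graph_start(self) -> None:
        self._started_at = time.time()

    @override
    def on_event(self, event: GraphEngineEvent) -> None:
        if isinstance(event, NodeRunSucceededEvent):
            self.metrics[event.node_id] = event.elapsed_time

    @override
    def on_graph_end(self, error: Exception | None) -> None:
        if self._started_at is not None:
            self.metrics["__total__"] = time.time() - self._started_at
```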
-""" - -import logging -from collections.abc import Mapping -from typing import Any, final - -from typing_extensions import override - -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .base import GraphEngineLayer - - -@final -class DebugLoggingLayer(GraphEngineLayer): - """ - A layer that provides comprehensive logging of GraphEngine execution. - - This layer logs all events with configurable detail levels, helping developers - debug workflow execution and understand the flow of events. - """ - - def __init__( - self, - level: str = "INFO", - include_inputs: bool = False, - include_outputs: bool = True, - include_process_data: bool = False, - logger_name: str = "GraphEngine.Debug", - max_value_length: int = 500, - ) -> None: - """ - Initialize the debug logging layer. - - Args: - level: Logging level (DEBUG, INFO, WARNING, ERROR) - include_inputs: Whether to log node input values - include_outputs: Whether to log node output values - include_process_data: Whether to log node process data - logger_name: Name of the logger to use - max_value_length: Maximum length of logged values (truncated if longer) - """ - super().__init__() - self.level = level - self.include_inputs = include_inputs - self.include_outputs = include_outputs - self.include_process_data = include_process_data - self.max_value_length = max_value_length - - # Set up logger - self.logger = logging.getLogger(logger_name) - log_level = getattr(logging, level.upper(), logging.INFO) - self.logger.setLevel(log_level) - - # Track execution stats - self.node_count = 0 - self.success_count = 0 - self.failure_count = 0 - self.retry_count = 0 - - def _truncate_value(self, value: Any) -> str: - """Truncate long values for logging.""" - str_value = str(value) - if len(str_value) > self.max_value_length: - return str_value[: self.max_value_length] + "... 
(truncated)" - return str_value - - def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str: - """Format a dictionary or mapping for logging with truncation.""" - if not data: - return "{}" - - formatted_items: list[str] = [] - for key, value in data.items(): - formatted_value = self._truncate_value(value) - formatted_items.append(f" {key}: {formatted_value}") - - return "{\n" + ",\n".join(formatted_items) + "\n}" - - @override - def on_graph_start(self) -> None: - """Log graph execution start.""" - self.logger.info("=" * 80) - self.logger.info("🚀 GRAPH EXECUTION STARTED") - self.logger.info("=" * 80) - # Log initial state - self.logger.info("Initial State:") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """Log individual events based on their type.""" - event_class = event.__class__.__name__ - - # Graph-level events - if isinstance(event, GraphRunStartedEvent): - self.logger.debug("Graph run started event") - - elif isinstance(event, GraphRunSucceededEvent): - self.logger.info("✅ Graph run succeeded") - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunPartialSucceededEvent): - self.logger.warning("⚠️ Graph run partially succeeded") - if event.exceptions_count > 0: - self.logger.warning(" Total exceptions: %s", event.exceptions_count) - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunFailedEvent): - self.logger.error("❌ Graph run failed: %s", event.error) - if event.exceptions_count > 0: - self.logger.error(" Total exceptions: %s", event.exceptions_count) - - elif isinstance(event, GraphRunAbortedEvent): - self.logger.warning("⚠️ Graph run aborted: %s", event.reason) - if event.outputs: - self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs)) - - # Node-level events - # Retry before Started because Retry subclasses Started; - elif isinstance(event, NodeRunRetryEvent): - self.retry_count += 1 - self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index) - self.logger.warning(" Previous error: %s", event.error) - - elif isinstance(event, NodeRunStartedEvent): - self.node_count += 1 - self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type) - - if self.include_inputs and event.node_run_result.inputs: - self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs)) - - elif isinstance(event, NodeRunSucceededEvent): - self.success_count += 1 - self.logger.info("✅ Node succeeded: %s", event.node_id) - - if self.include_outputs and event.node_run_result.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs)) - - if self.include_process_data and event.node_run_result.process_data: - self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data)) - - elif isinstance(event, NodeRunFailedEvent): - self.failure_count += 1 - self.logger.error("❌ Node failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - if event.node_run_result.error: - self.logger.error(" Details: %s", event.node_run_result.error) - - elif isinstance(event, NodeRunExceptionEvent): - self.logger.warning("⚠️ Node exception handled: %s", event.node_id) - self.logger.warning(" Error: %s", event.error) - - elif isinstance(event, NodeRunStreamChunkEvent): - # Log stream chunks at debug 
level to avoid spam - final_indicator = " (FINAL)" if event.is_final else "" - self.logger.debug( - "📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk) - ) - - # Iteration events - elif isinstance(event, NodeRunIterationStartedEvent): - self.logger.info("🔁 Iteration started: %s", event.node_id) - - elif isinstance(event, NodeRunIterationNextEvent): - self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunIterationSucceededEvent): - self.logger.info("✅ Iteration succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunIterationFailedEvent): - self.logger.error("❌ Iteration failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - # Loop events - elif isinstance(event, NodeRunLoopStartedEvent): - self.logger.info("🔄 Loop started: %s", event.node_id) - - elif isinstance(event, NodeRunLoopNextEvent): - self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunLoopSucceededEvent): - self.logger.info("✅ Loop succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunLoopFailedEvent): - self.logger.error("❌ Loop failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - else: - # Log unknown events at debug level - self.logger.debug("Event: %s", event_class) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Log graph execution end with summary statistics.""" - self.logger.info("=" * 80) - - if error: - self.logger.error("🔴 GRAPH EXECUTION FAILED") - self.logger.error(" Error: %s", error) - else: - self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY") - - # Log execution statistics - self.logger.info("Execution Statistics:") - self.logger.info(" Total nodes executed: %s", self.node_count) - self.logger.info(" Successful nodes: %s", self.success_count) - self.logger.info(" Failed nodes: %s", self.failure_count) - self.logger.info(" Node retries: %s", self.retry_count) - - # Log final state if available - if self.include_outputs and self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) - - self.logger.info("=" * 80) diff --git a/api/graphon/graph_engine/layers/execution_limits.py b/api/graphon/graph_engine/layers/execution_limits.py deleted file mode 100644 index 2742b3acd3..0000000000 --- a/api/graphon/graph_engine/layers/execution_limits.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Execution limits layer for GraphEngine. - -This layer monitors workflow execution to enforce limits on: -- Maximum execution steps -- Maximum execution time - -When limits are exceeded, the layer automatically aborts execution. 
-""" - -import logging -import time -from enum import StrEnum -from typing import final - -from typing_extensions import override - -from graphon.graph_engine.entities.commands import AbortCommand, CommandType -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import ( - GraphEngineEvent, - NodeRunStartedEvent, -) -from graphon.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent - - -class LimitType(StrEnum): - """Types of execution limits that can be exceeded.""" - - STEP_LIMIT = "step_limit" - TIME_LIMIT = "time_limit" - - -@final -class ExecutionLimitsLayer(GraphEngineLayer): - """ - Layer that enforces execution limits for workflows. - - Monitors: - - Step count: Tracks number of node executions - - Time limit: Monitors total execution time - - Automatically aborts execution when limits are exceeded. - """ - - def __init__(self, max_steps: int, max_time: int) -> None: - """ - Initialize the execution limits layer. - - Args: - max_steps: Maximum number of execution steps allowed - max_time: Maximum execution time in seconds allowed - """ - super().__init__() - self.max_steps = max_steps - self.max_time = max_time - - # Runtime tracking - self.start_time: float | None = None - self.step_count = 0 - self.logger = logging.getLogger(__name__) - - # State tracking - self._execution_started = False - self._execution_ended = False - self._abort_sent = False # Track if abort command has been sent - - @override - def on_graph_start(self) -> None: - """Called when graph execution starts.""" - self.start_time = time.time() - self.step_count = 0 - self._execution_started = True - self._execution_ended = False - self._abort_sent = False - - self.logger.debug("Execution limits monitoring started") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - Monitors execution progress and enforces limits. - """ - if not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Track step count for node execution events - if isinstance(event, NodeRunStartedEvent): - self.step_count += 1 - self.logger.debug("Step %d started: %s", self.step_count, event.node_id) - - # Check step limit when node execution completes - if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent): - if self._reached_step_limitation(): - self._send_abort_command(LimitType.STEP_LIMIT) - - if self._reached_time_limitation(): - self._send_abort_command(LimitType.TIME_LIMIT) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Called when graph execution ends.""" - if self._execution_started and not self._execution_ended: - self._execution_ended = True - - if self.start_time: - total_time = time.time() - self.start_time - self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time) - - def _reached_step_limitation(self) -> bool: - """Check if step count limit has been exceeded.""" - return self.step_count > self.max_steps - - def _reached_time_limitation(self) -> bool: - """Check if time limit has been exceeded.""" - return self.start_time is not None and (time.time() - self.start_time) > self.max_time - - def _send_abort_command(self, limit_type: LimitType) -> None: - """ - Send abort command due to limit violation. 
- - Args: - limit_type: Type of limit exceeded - """ - if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Format detailed reason message - if limit_type == LimitType.STEP_LIMIT: - reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}" - elif limit_type == LimitType.TIME_LIMIT: - elapsed_time = time.time() - self.start_time if self.start_time else 0 - reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s" - - self.logger.warning("Execution limit exceeded: %s", reason) - - try: - # Send abort command to the engine - abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason) - self.command_channel.send_command(abort_command) - - # Mark that abort has been sent to prevent duplicate commands - self._abort_sent = True - - self.logger.debug("Abort command sent to engine") - - except Exception: - self.logger.exception("Failed to send abort command") diff --git a/api/graphon/graph_engine/manager.py b/api/graphon/graph_engine/manager.py deleted file mode 100644 index c728ff6986..0000000000 --- a/api/graphon/graph_engine/manager.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -GraphEngine Manager for sending control commands via Redis channel. - -This module provides a simplified interface for controlling workflow executions -using the new Redis command channel, without requiring user permission checks. -Callers must provide a Redis client dependency from outside the workflow package. -""" - -import logging -from collections.abc import Sequence -from typing import final - -from graphon.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from graphon.graph_engine.entities.commands import ( - AbortCommand, - GraphEngineCommand, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) - -logger = logging.getLogger(__name__) - - -@final -class GraphEngineManager: - """ - Manager for sending control commands to GraphEngine instances. - - This class provides a simple interface for controlling workflow executions - by sending commands through Redis channels, without user validation. - """ - - _redis_client: RedisClientProtocol - - def __init__(self, redis_client: RedisClientProtocol) -> None: - self._redis_client = redis_client - - def send_stop_command(self, task_id: str, reason: str | None = None) -> None: - """ - Send a stop command to a running workflow. 
- - Args: - task_id: The task ID of the workflow to stop - reason: Optional reason for stopping (defaults to "User requested stop") - """ - abort_command = AbortCommand(reason=reason or "User requested stop") - self._send_command(task_id, abort_command) - - def send_pause_command(self, task_id: str, reason: str | None = None) -> None: - """Send a pause command to a running workflow.""" - - pause_command = PauseCommand(reason=reason or "User requested pause") - self._send_command(task_id, pause_command) - - def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None: - """Send a command to update variables in a running workflow.""" - - if not updates: - return - - update_command = UpdateVariablesCommand(updates=updates) - self._send_command(task_id, update_command) - - def _send_command(self, task_id: str, command: GraphEngineCommand) -> None: - """Send a command to the workflow-specific Redis channel.""" - - if not task_id: - return - - channel_key = f"workflow:{task_id}:commands" - channel = RedisChannel(self._redis_client, channel_key) - - try: - channel.send_command(command) - except Exception: - # Silently fail if Redis is unavailable - # The legacy control mechanisms will still work - logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id) diff --git a/api/graphon/graph_engine/orchestration/__init__.py b/api/graphon/graph_engine/orchestration/__init__.py deleted file mode 100644 index de08e942fb..0000000000 --- a/api/graphon/graph_engine/orchestration/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Orchestration subsystem for graph engine. - -This package coordinates the overall execution flow between -different subsystems. -""" - -from .dispatcher import Dispatcher -from .execution_coordinator import ExecutionCoordinator - -__all__ = [ - "Dispatcher", - "ExecutionCoordinator", -] diff --git a/api/graphon/graph_engine/orchestration/dispatcher.py b/api/graphon/graph_engine/orchestration/dispatcher.py deleted file mode 100644 index f75bbee08e..0000000000 --- a/api/graphon/graph_engine/orchestration/dispatcher.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Main dispatcher for processing events from workers. -""" - -import logging -import queue -import threading -import time -from typing import TYPE_CHECKING, final - -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunSucceededEvent, -) - -from ..event_management import EventManager -from .execution_coordinator import ExecutionCoordinator - -if TYPE_CHECKING: - from ..event_management import EventHandler - -logger = logging.getLogger(__name__) - - -@final -class Dispatcher: - """ - Main dispatcher that processes events from the event queue. - - This runs in a separate thread and coordinates event processing - with timeout and completion detection. - """ - - _COMMAND_TRIGGER_EVENTS = ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunExceptionEvent, - ) - - def __init__( - self, - event_queue: queue.Queue[GraphNodeEventBase], - event_handler: "EventHandler", - execution_coordinator: ExecutionCoordinator, - event_emitter: EventManager | None = None, - ) -> None: - """ - Initialize the dispatcher. 
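The `GraphEngineManager` above reduces to one Redis channel per task, keyed `workflow:{task_id}:commands`. A usage sketch — whether `redis.Redis` satisfies `RedisClientProtocol` is an assumption, since the protocol definition sits outside this excerpt:

```python
from redis import Redis

from graphon.graph_engine.manager import GraphEngineManager

manager = GraphEngineManager(redis_client=Redis(host="localhost", port=6379))

# Publishes an AbortCommand to "workflow:<task_id>:commands"; send failures
# are logged rather than raised, per _send_command above.
manager.send_stop_command(task_id="example-task-id", reason="operator requested stop")
```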
- - Args: - event_queue: Queue of events from workers - event_handler: Event handler registry for processing events - execution_coordinator: Coordinator for execution flow - event_emitter: Optional event manager to signal completion - """ - self._event_queue = event_queue - self._event_handler = event_handler - self._execution_coordinator = execution_coordinator - self._event_emitter = event_emitter - - self._thread: threading.Thread | None = None - self._stop_event = threading.Event() - self._start_time: float | None = None - - def start(self) -> None: - """Start the dispatcher thread.""" - if self._thread and self._thread.is_alive(): - return - - self._stop_event.clear() - self._start_time = time.time() - self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) - self._thread.start() - - def stop(self) -> None: - """Stop the dispatcher thread.""" - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=2.0) - - def _dispatcher_loop(self) -> None: - """Main dispatcher loop.""" - try: - self._process_commands() - paused = False - while not self._stop_event.is_set(): - if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete: - break - if self._execution_coordinator.paused: - paused = True - break - - self._execution_coordinator.check_scaling() - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - time.sleep(0.1) - - self._process_commands() - if paused: - self._drain_events_until_idle() - else: - self._drain_event_queue() - - except Exception as e: - logger.exception("Dispatcher error") - self._execution_coordinator.mark_failed(e) - - finally: - self._execution_coordinator.mark_complete() - # Signal the event emitter that execution is complete - if self._event_emitter: - self._event_emitter.mark_complete() - - def _process_commands(self, event: GraphNodeEventBase | None = None): - if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS): - self._execution_coordinator.process_commands() - - def _drain_event_queue(self) -> None: - while True: - try: - event = self._event_queue.get(block=False) - self._event_handler.dispatch(event) - self._event_queue.task_done() - except queue.Empty: - break - - def _drain_events_until_idle(self) -> None: - while not self._stop_event.is_set(): - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - if not self._execution_coordinator.has_executing_nodes(): - break - self._drain_event_queue() diff --git a/api/graphon/graph_engine/orchestration/execution_coordinator.py b/api/graphon/graph_engine/orchestration/execution_coordinator.py deleted file mode 100644 index 0f8550eb12..0000000000 --- a/api/graphon/graph_engine/orchestration/execution_coordinator.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Execution coordinator for managing overall workflow execution. -""" - -from typing import final - -from ..command_processing import CommandProcessor -from ..domain import GraphExecution -from ..graph_state_manager import GraphStateManager -from ..worker_management import WorkerPool - - -@final -class ExecutionCoordinator: - """ - Coordinates overall execution flow between subsystems. - - This provides high-level coordination methods used by the - dispatcher to manage execution state. 
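The dispatcher's final non-blocking drain (`_drain_event_queue`) is a standard pattern worth seeing on its own; a minimal sketch with a plain string queue:

```python
import queue

events: queue.Queue[str] = queue.Queue()
for name in ("node_started", "node_succeeded"):
    events.put(name)

drained: list[str] = []
while True:
    try:
        drained.append(events.get(block=False))  # never blocks once empty
        events.task_done()
    except queue.Empty:
        break

assert drained == ["node_started", "node_succeeded"]
```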
- """ - - def __init__( - self, - graph_execution: GraphExecution, - state_manager: GraphStateManager, - command_processor: CommandProcessor, - worker_pool: WorkerPool, - ) -> None: - """ - Initialize the execution coordinator. - - Args: - graph_execution: Graph execution aggregate - state_manager: Unified state manager - command_processor: Processor for commands - worker_pool: Pool of workers - """ - self._graph_execution = graph_execution - self._state_manager = state_manager - self._command_processor = command_processor - self._worker_pool = worker_pool - - def process_commands(self) -> None: - """Process any pending commands.""" - self._command_processor.process_commands() - - def check_scaling(self) -> None: - """Check and perform worker scaling if needed.""" - self._worker_pool.check_and_scale() - - @property - def execution_complete(self): - return self._state_manager.is_execution_complete() - - @property - def aborted(self): - return self._graph_execution.aborted or self._graph_execution.has_error - - @property - def paused(self) -> bool: - """Expose whether the underlying graph execution is paused.""" - return self._graph_execution.is_paused - - def mark_complete(self) -> None: - """Mark execution as complete.""" - if self._graph_execution.is_paused: - return - if not self._graph_execution.completed: - self._graph_execution.complete() - - def mark_failed(self, error: Exception) -> None: - """ - Mark execution as failed. - - Args: - error: The error that caused failure - """ - self._graph_execution.fail(error) - - def handle_pause_if_needed(self) -> None: - """If the execution has been paused, stop workers immediately.""" - - if not self._graph_execution.is_paused: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def handle_abort_if_needed(self) -> None: - """If the execution has been aborted, stop workers immediately.""" - - if not self._graph_execution.aborted: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def has_executing_nodes(self) -> bool: - """Return True if any nodes are currently marked as executing.""" - # This check is only safe once execution has already paused. - # Before pause, executing state can change concurrently, which makes the result unreliable. - if not self._graph_execution.is_paused: - raise AssertionError("has_executing_nodes should only be called after execution is paused") - return self._state_manager.get_executing_count() > 0 diff --git a/api/graphon/graph_engine/protocols/command_channel.py b/api/graphon/graph_engine/protocols/command_channel.py deleted file mode 100644 index fabd8634c8..0000000000 --- a/api/graphon/graph_engine/protocols/command_channel.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -CommandChannel protocol for GraphEngine command communication. - -This protocol defines the interface for sending and receiving commands -to/from a GraphEngine instance, supporting both local and distributed scenarios. -""" - -from typing import Protocol - -from ..entities.commands import GraphEngineCommand - - -class CommandChannel(Protocol): - """ - Protocol for bidirectional command communication with GraphEngine. - - Since each GraphEngine instance processes only one workflow execution, - this channel is dedicated to that single execution. - """ - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch pending commands for this GraphEngine instance. - - Called by GraphEngine to poll for commands that need to be processed. 
- - Returns: - List of pending commands (may be empty) - """ - ... - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to be processed by this GraphEngine instance. - - Called by external systems to send control commands to the running workflow. - - Args: - command: The command to send - """ - ... diff --git a/api/graphon/graph_engine/ready_queue/__init__.py b/api/graphon/graph_engine/ready_queue/__init__.py deleted file mode 100644 index acba0e961c..0000000000 --- a/api/graphon/graph_engine/ready_queue/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Ready queue implementations for GraphEngine. - -This package contains the protocol and implementations for managing -the queue of nodes ready for execution. -""" - -from .factory import create_ready_queue_from_state -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueue, ReadyQueueState - -__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"] diff --git a/api/graphon/graph_engine/ready_queue/factory.py b/api/graphon/graph_engine/ready_queue/factory.py deleted file mode 100644 index a9d4f470e5..0000000000 --- a/api/graphon/graph_engine/ready_queue/factory.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Factory for creating ReadyQueue instances from serialized state. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueueState - -if TYPE_CHECKING: - from .protocol import ReadyQueue - - -def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue: - """ - Create a ReadyQueue instance from a serialized state. - - Args: - state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue - - Returns: - A ReadyQueue instance initialized with the given state - - Raises: - ValueError: If the queue type is unknown or version is unsupported - """ - if state.type == "InMemoryReadyQueue": - if state.version != "1.0": - raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}") - queue = InMemoryReadyQueue() - # Always pass as JSON string to loads() - queue.loads(state.model_dump_json()) - return queue - else: - raise ValueError(f"Unknown ready queue type: {state.type}") diff --git a/api/graphon/graph_engine/ready_queue/in_memory.py b/api/graphon/graph_engine/ready_queue/in_memory.py deleted file mode 100644 index f2c265ece0..0000000000 --- a/api/graphon/graph_engine/ready_queue/in_memory.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -In-memory implementation of the ReadyQueue protocol. - -This implementation wraps Python's standard queue.Queue and adds -serialization capabilities for state storage. -""" - -import queue -from typing import final - -from .protocol import ReadyQueue, ReadyQueueState - - -@final -class InMemoryReadyQueue(ReadyQueue): - """ - In-memory ready queue implementation with serialization support. - - This implementation uses Python's queue.Queue internally and provides - methods to serialize and restore the queue state. - """ - - def __init__(self, maxsize: int = 0) -> None: - """ - Initialize the in-memory ready queue. - - Args: - maxsize: Maximum size of the queue (0 for unlimited) - """ - self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize) - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. 
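Because `CommandChannel` above is a structural protocol, a queue-backed stub satisfies it without inheritance — handy for single-process tests. An illustrative implementation, not part of the package:

```python
import queue

from graphon.graph_engine.entities.commands import GraphEngineCommand


class InMemoryCommandChannel:
    """Queue-backed CommandChannel stub for single-process tests."""

    def __init__(self) -> None:
        self._queue: queue.Queue[GraphEngineCommand] = queue.Queue()

    def fetch_commands(self) -> list[GraphEngineCommand]:
        commands: list[GraphEngineCommand] = []
        while True:
            try:
                commands.append(self._queue.get_nowait())
            except queue.Empty:
                return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        self._queue.put(command)
```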
- - Args: - item: The node ID to add to the queue - """ - self._queue.put(item) - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - if timeout is None: - return self._queue.get(block=True) - return self._queue.get(timeout=timeout) - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - self._queue.task_done() - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - return self._queue.empty() - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - return self._queue.qsize() - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - """ - # Extract all items from the queue without removing them - items: list[str] = [] - temp_items: list[str] = [] - - # Drain the queue temporarily to get all items - while not self._queue.empty(): - try: - item = self._queue.get_nowait() - temp_items.append(item) - items.append(item) - except queue.Empty: - break - - # Put items back in the same order - for item in temp_items: - self._queue.put(item) - - state = ReadyQueueState( - type="InMemoryReadyQueue", - version="1.0", - items=items, - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - state = ReadyQueueState.model_validate_json(data) - - if state.type != "InMemoryReadyQueue": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported version: {state.version}") - - # Clear the current queue - while not self._queue.empty(): - try: - self._queue.get_nowait() - except queue.Empty: - break - - # Restore items - for item in state.items: - self._queue.put(item) diff --git a/api/graphon/graph_engine/ready_queue/protocol.py b/api/graphon/graph_engine/ready_queue/protocol.py deleted file mode 100644 index 97d3ea6dd2..0000000000 --- a/api/graphon/graph_engine/ready_queue/protocol.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -ReadyQueue protocol for GraphEngine node execution queue. - -This protocol defines the interface for managing the queue of nodes ready -for execution, supporting both in-memory and persistent storage scenarios. -""" - -from collections.abc import Sequence -from typing import Protocol - -from pydantic import BaseModel, Field - - -class ReadyQueueState(BaseModel): - """ - Pydantic model for serialized ready queue state. - - This defines the structure of the data returned by dumps() - and expected by loads() for ready queue serialization. - """ - - type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')") - version: str = Field(description="Serialization format version") - items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue") - - -class ReadyQueue(Protocol): - """ - Protocol for managing nodes ready for execution in GraphEngine. 
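The serialization contract above round-trips cleanly: `dumps()` snapshots the pending node IDs without consuming them, and `loads()` restores FIFO order on a fresh queue. A sketch, assuming the graphon package is importable:

```python
from graphon.graph_engine.ready_queue import InMemoryReadyQueue

q = InMemoryReadyQueue()
q.put("node-a")
q.put("node-b")

snapshot = q.dumps()
assert q.qsize() == 2  # dumps() drains and re-enqueues, so size is unchanged

restored = InMemoryReadyQueue()
restored.loads(snapshot)
assert restored.get(timeout=1) == "node-a"  # FIFO order survives the round trip
```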
- - This protocol defines the interface that any ready queue implementation - must provide, enabling both in-memory queues and persistent queues - that can be serialized for state storage. - """ - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. - - Args: - item: The node ID to add to the queue - """ - ... - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - ... - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - ... - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - ... - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - ... - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - that can be persisted and later restored - """ - ... - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - ... diff --git a/api/graphon/graph_engine/response_coordinator/__init__.py b/api/graphon/graph_engine/response_coordinator/__init__.py deleted file mode 100644 index e11d31199c..0000000000 --- a/api/graphon/graph_engine/response_coordinator/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -ResponseStreamCoordinator - Coordinates streaming output from response nodes - -This component manages response streaming sessions and ensures ordered streaming -of responses based on upstream node outputs and constants. -""" - -from .coordinator import ResponseStreamCoordinator - -__all__ = ["ResponseStreamCoordinator"] diff --git a/api/graphon/graph_engine/response_coordinator/coordinator.py b/api/graphon/graph_engine/response_coordinator/coordinator.py deleted file mode 100644 index a6562f0223..0000000000 --- a/api/graphon/graph_engine/response_coordinator/coordinator.py +++ /dev/null @@ -1,697 +0,0 @@ -""" -Main ResponseStreamCoordinator implementation. - -This module contains the public ResponseStreamCoordinator class that manages -response streaming sessions and ensures ordered streaming of responses. 
-""" - -import logging -from collections import deque -from collections.abc import Sequence -from threading import RLock -from typing import Literal, TypeAlias, final -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from graphon.enums import NodeExecutionType, NodeState -from graphon.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent -from graphon.nodes.base.template import TextSegment, VariableSegment -from graphon.runtime import VariablePool -from graphon.runtime.graph_runtime_state import GraphProtocol - -from .path import Path -from .session import ResponseSession - -logger = logging.getLogger(__name__) - -# Type definitions -NodeID: TypeAlias = str -EdgeID: TypeAlias = str - - -class ResponseSessionState(BaseModel): - """Serializable representation of a response session.""" - - node_id: str - index: int = Field(default=0, ge=0) - - -class StreamBufferState(BaseModel): - """Serializable representation of buffered stream chunks.""" - - selector: tuple[str, ...] - events: list[NodeRunStreamChunkEvent] = Field(default_factory=list) - - -class StreamPositionState(BaseModel): - """Serializable representation for stream read positions.""" - - selector: tuple[str, ...] - position: int = Field(default=0, ge=0) - - -class ResponseStreamCoordinatorState(BaseModel): - """Serialized snapshot of ResponseStreamCoordinator.""" - - type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator") - version: str = Field(default="1.0") - response_nodes: Sequence[str] = Field(default_factory=list) - active_session: ResponseSessionState | None = None - waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - node_execution_ids: dict[str, str] = Field(default_factory=dict) - paths_map: dict[str, list[list[str]]] = Field(default_factory=dict) - stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list) - stream_positions: Sequence[StreamPositionState] = Field(default_factory=list) - closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list) - - -@final -class ResponseStreamCoordinator: - """ - Manages response streaming sessions without relying on global state. - - Ensures ordered streaming of responses based on upstream node outputs and constants. - """ - - def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None: - """ - Initialize coordinator with variable pool. 
- - Args: - variable_pool: VariablePool instance for accessing node variables - graph: Graph instance for looking up node information - """ - self._variable_pool = variable_pool - self._graph = graph - self._active_session: ResponseSession | None = None - self._waiting_sessions: deque[ResponseSession] = deque() - self._lock = RLock() - - # Internal stream management (replacing OutputRegistry) - self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} - self._stream_positions: dict[tuple[str, ...], int] = {} - self._closed_streams: set[tuple[str, ...]] = set() - - # Track response nodes - self._response_nodes: set[NodeID] = set() - - # Store paths for each response node - self._paths_maps: dict[NodeID, list[Path]] = {} - - # Track node execution IDs and types for proper event forwarding - self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id - - # Track response sessions to ensure only one per node - self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session - - def register(self, response_node_id: NodeID) -> None: - with self._lock: - if response_node_id in self._response_nodes: - return - self._response_nodes.add(response_node_id) - - # Build and save paths map for this response node - paths_map = self._build_paths_map(response_node_id) - self._paths_maps[response_node_id] = paths_map - - # Create and store response session for this node - response_node = self._graph.nodes[response_node_id] - session = ResponseSession.from_node(response_node) - self._response_sessions[response_node_id] = session - - def track_node_execution(self, node_id: NodeID, execution_id: str) -> None: - """Track the execution ID for a node when it starts executing. - - Args: - node_id: The ID of the node - execution_id: The execution ID from NodeRunStartedEvent - """ - with self._lock: - self._node_execution_ids[node_id] = execution_id - - def _get_or_create_execution_id(self, node_id: NodeID) -> str: - """Get the execution ID for a node, creating one if it doesn't exist. - - Args: - node_id: The ID of the node - - Returns: - The execution ID for the node - """ - with self._lock: - if node_id not in self._node_execution_ids: - self._node_execution_ids[node_id] = str(uuid4()) - return self._node_execution_ids[node_id] - - def _build_paths_map(self, response_node_id: NodeID) -> list[Path]: - """ - Build a paths map for a response node by finding all paths from root node - to the response node, recording branch edges along each path. 
- - Args: - response_node_id: ID of the response node to analyze - - Returns: - List of Path objects, where each path contains branch edge IDs - """ - # Get root node ID - root_node_id = self._graph.root_node.id - - # If root is the response node, return empty path - if root_node_id == response_node_id: - return [Path()] - - # Extract variable selectors from the response node's template - response_node = self._graph.nodes[response_node_id] - response_session = ResponseSession.from_node(response_node) - template = response_session.template - - # Collect all variable selectors from the template - variable_selectors: set[tuple[str, ...]] = set() - for segment in template.segments: - if isinstance(segment, VariableSegment): - variable_selectors.add(tuple(segment.selector[:2])) - - # Step 1: Find all complete paths from root to response node - all_complete_paths: list[list[EdgeID]] = [] - - def find_paths( - current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID] - ) -> None: - """Recursively find all paths from current node to target node.""" - if current_node_id == target_node_id: - # Found a complete path, store it - all_complete_paths.append(current_path.copy()) - return - - # Mark as visited to avoid cycles - visited.add(current_node_id) - - # Explore outgoing edges - outgoing_edges = self._graph.get_outgoing_edges(current_node_id) - for edge in outgoing_edges: - edge_id = edge.id - next_node_id = edge.head - - # Skip if already visited in this path - if next_node_id not in visited: - # Add edge to path and recurse - new_path = current_path + [edge_id] - find_paths(next_node_id, target_node_id, new_path, visited.copy()) - - # Start searching from root node - find_paths(root_node_id, response_node_id, [], set()) - - # Step 2: For each complete path, filter edges based on node blocking behavior - filtered_paths: list[Path] = [] - for path in all_complete_paths: - blocking_edges: list[str] = [] - for edge_id in path: - edge = self._graph.edges[edge_id] - source_node = self._graph.nodes[edge.tail] - - # Check if node is a branch, container, or response node - if source_node.execution_type in { - NodeExecutionType.BRANCH, - NodeExecutionType.CONTAINER, - NodeExecutionType.RESPONSE, - } or source_node.blocks_variable_output(variable_selectors): - blocking_edges.append(edge_id) - - # Keep the path even if it's empty - filtered_paths.append(Path(edges=blocking_edges)) - - return filtered_paths - - def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Handle when an edge is taken (selected by a branch node). - - This method updates the paths for all response nodes by removing - the taken edge. If any response node has an empty path after removal, - it means the node is now deterministically reachable and should start. 
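The paths map reduces "is this response node deterministically reachable?" to list bookkeeping. A plain-list sketch of the pruning idea (edge IDs are illustrative; the real `Path` type appears in path.py further down):

```python
# Each inner list is one path: the branch edges still blocking the response
# node. Taking an edge prunes it everywhere; an emptied path means reachable.
paths: list[list[str]] = [["if-1:true", "if-2:true"], ["if-3:false"]]


def on_edge_taken(edge_id: str) -> bool:
    for path in paths:
        if edge_id in path:
            path.remove(edge_id)
    return any(not path for path in paths)


assert on_edge_taken("if-1:true") is False  # first path still blocked by if-2
assert on_edge_taken("if-3:false") is True  # second path emptied: start streaming
```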
- - Args: - edge_id: The ID of the edge that was taken - - Returns: - List of events to emit from starting new sessions - """ - events: list[NodeRunStreamChunkEvent] = [] - - with self._lock: - # Check each response node in order - for response_node_id in self._response_nodes: - if response_node_id not in self._paths_maps: - continue - - paths = self._paths_maps[response_node_id] - has_reachable_path = False - - # Update each path by removing the taken edge - for path in paths: - # Remove the taken edge from this path - path.remove_edge(edge_id) - - # Check if this path is now empty (node is reachable) - if path.is_empty(): - has_reachable_path = True - - # If node is now reachable (has empty path), start/queue session - if has_reachable_path: - # Pass the node_id to the activation method - # The method will handle checking and removing from map - events.extend(self._active_or_queue_session(response_node_id)) - return events - - def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Start a session immediately if no active session, otherwise queue it. - Only activates sessions that exist in the _response_sessions map. - - Args: - node_id: The ID of the response node to activate - - Returns: - List of events from flush attempt if session started immediately - """ - events: list[NodeRunStreamChunkEvent] = [] - - # Get the session from our map (only activate if it exists) - session = self._response_sessions.get(node_id) - if not session: - return events - - # Remove from map to ensure it won't be activated again - del self._response_sessions[node_id] - - if self._active_session is None: - self._active_session = session - - # Try to flush immediately - events.extend(self.try_flush()) - else: - # Queue the session if another is active - self._waiting_sessions.append(session) - - return events - - def intercept_event( - self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent - ) -> Sequence[NodeRunStreamChunkEvent]: - with self._lock: - if isinstance(event, NodeRunStreamChunkEvent): - self._append_stream_chunk(event.selector, event) - if event.is_final: - self._close_stream(event.selector) - return self.try_flush() - else: - # Skip cause we share the same variable pool. - # - # for variable_name, variable_value in event.node_run_result.outputs.items(): - # self._variable_pool.add((event.node_id, variable_name), variable_value) - return self.try_flush() - - def _create_stream_chunk_event( - self, - node_id: str, - execution_id: str, - selector: Sequence[str], - chunk: str, - is_final: bool = False, - ) -> NodeRunStreamChunkEvent: - """Create a stream chunk event with consistent structure. - - For selectors with special prefixes (sys, env, conversation), we use the - active response node's information since these are not actual node IDs. 
- """ - # Check if this is a special selector that doesn't correspond to a node - if selector and selector[0] not in self._graph.nodes and self._active_session: - # Use the active response node for special selectors - response_node = self._graph.nodes[self._active_session.node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - # Standard case: selector refers to an actual node - node = self._graph.nodes[node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=node.id, - node_type=node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]: - """Process a variable segment. Returns (events, is_complete). - - Handles both regular node selectors and special system selectors (sys, env, conversation). - For special selectors, we attribute the output to the active response node. - """ - events: list[NodeRunStreamChunkEvent] = [] - source_selector_prefix = segment.selector[0] if segment.selector else "" - is_complete = False - - # Determine which node to attribute the output to - # For special selectors (sys, env, conversation), use the active response node - # For regular selectors, use the source node - if self._active_session and source_selector_prefix not in self._graph.nodes: - # Special selector - use active response node - output_node_id = self._active_session.node_id - else: - # Regular node selector - output_node_id = source_selector_prefix - execution_id = self._get_or_create_execution_id(output_node_id) - - # Stream all available chunks - while self._has_unread_stream(segment.selector): - if event := self._pop_stream_chunk(segment.selector): - # For special selectors, we need to update the event to use - # the active response node's information - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - # Create a new event with the response node's information - # but keep the original selector - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, # Keep original selector - chunk=event.chunk, - is_final=event.is_final, - ) - events.append(updated_event) - else: - # Regular node selector - use event as is - events.append(event) - - # Check if this is the last chunk by looking ahead - stream_closed = self._is_stream_closed(segment.selector) - # Check if stream is closed to determine if segment is complete - if stream_closed: - is_complete = True - - elif value := self._variable_pool.get(segment.selector): - # Process scalar value - is_last_segment = bool( - self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1 - ) - events.append( - self._create_stream_chunk_event( - node_id=output_node_id, - execution_id=execution_id, - selector=segment.selector, - chunk=value.markdown, - is_final=is_last_segment, - ) - ) - is_complete = True - - return events, is_complete - - def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: - """Process a text segment. 
Returns the stream chunk events for the segment."""
-        assert self._active_session is not None
-        current_response_node = self._graph.nodes[self._active_session.node_id]
-
-        # Use get_or_create_execution_id to ensure we have a consistent ID
-        execution_id = self._get_or_create_execution_id(current_response_node.id)
-
-        is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
-        event = self._create_stream_chunk_event(
-            node_id=current_response_node.id,
-            execution_id=execution_id,
-            selector=[current_response_node.id, "answer"],  # FIXME(-LAN-)
-            chunk=segment.text,
-            is_final=is_last_segment,
-        )
-        return [event]
-
-    def try_flush(self) -> list[NodeRunStreamChunkEvent]:
-        with self._lock:
-            if not self._active_session:
-                return []
-
-            template = self._active_session.template
-            response_node_id = self._active_session.node_id
-
-            events: list[NodeRunStreamChunkEvent] = []
-
-            # Process segments sequentially from current index
-            while self._active_session.index < len(template.segments):
-                segment = template.segments[self._active_session.index]
-
-                if isinstance(segment, VariableSegment):
-                    # Check if the source node for this variable is skipped
-                    # Only check for actual nodes, not special selectors (sys, env, conversation)
-                    source_selector_prefix = segment.selector[0] if segment.selector else ""
-                    if source_selector_prefix in self._graph.nodes:
-                        source_node = self._graph.nodes[source_selector_prefix]
-
-                        if source_node.state == NodeState.SKIPPED:
-                            # Skip this variable segment if the source node is skipped
-                            self._active_session.index += 1
-                            continue
-
-                    segment_events, is_complete = self._process_variable_segment(segment)
-                    events.extend(segment_events)
-
-                    # Only advance index if this variable segment is complete
-                    if is_complete:
-                        self._active_session.index += 1
-                    else:
-                        # Wait for more data
-                        break
-
-                else:
-                    segment_events = self._process_text_segment(segment)
-                    events.extend(segment_events)
-                    self._active_session.index += 1
-
-            if self._active_session.is_complete():
-                # End current session and get events from starting next session
-                next_session_events = self.end_session(response_node_id)
-                events.extend(next_session_events)
-
-            return events
-
-    def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
-        """
-        End the active session for a response node.
-        Automatically starts the next waiting session if available.
-
-        Args:
-            node_id: ID of the response node ending its session
-
-        Returns:
-            List of events from starting the next session
-        """
-        with self._lock:
-            events: list[NodeRunStreamChunkEvent] = []
-
-            if self._active_session and self._active_session.node_id == node_id:
-                self._active_session = None
-
-                # Try to start next waiting session
-                if self._waiting_sessions:
-                    next_session = self._waiting_sessions.popleft()
-                    self._active_session = next_session
-
-                    # Immediately try to flush any available segments
-                    events = self.try_flush()
-
-            return events
-
-    # ============= Internal Stream Management Methods =============
-
-    def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
-        """
-        Append a stream chunk to the internal buffer.
- - Args: - selector: List of strings identifying the stream location - event: The NodeRunStreamChunkEvent to append - - Raises: - ValueError: If the stream is already closed - """ - key = tuple(selector) - - if key in self._closed_streams: - raise ValueError(f"Stream {'.'.join(selector)} is already closed") - - if key not in self._stream_buffers: - self._stream_buffers[key] = [] - self._stream_positions[key] = 0 - - self._stream_buffers[key].append(event) - - def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None: - """ - Pop the next unread stream chunk from the buffer. - - Args: - selector: List of strings identifying the stream location - - Returns: - The next event, or None if no unread events available - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return None - - position = self._stream_positions.get(key, 0) - buffer = self._stream_buffers[key] - - if position >= len(buffer): - return None - - event = buffer[position] - self._stream_positions[key] = position + 1 - return event - - def _has_unread_stream(self, selector: Sequence[str]) -> bool: - """ - Check if the stream has unread events. - - Args: - selector: List of strings identifying the stream location - - Returns: - True if there are unread events, False otherwise - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return False - - position = self._stream_positions.get(key, 0) - return position < len(self._stream_buffers[key]) - - def _close_stream(self, selector: Sequence[str]) -> None: - """ - Mark a stream as closed (no more chunks can be appended). - - Args: - selector: List of strings identifying the stream location - """ - key = tuple(selector) - self._closed_streams.add(key) - - def _is_stream_closed(self, selector: Sequence[str]) -> bool: - """ - Check if a stream is closed. 
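The stream helpers above implement cursor-style reads: popping advances a per-selector position rather than mutating the buffer, and closing a stream only forbids further appends. The semantics, reduced to plain dicts:

```python
buffers: dict[tuple[str, str], list[str]] = {("llm-1", "text"): ["Hello", " world"]}
positions: dict[tuple[str, str], int] = {("llm-1", "text"): 0}
closed: set[tuple[str, str]] = set()


def pop(key: tuple[str, str]) -> str | None:
    if positions[key] >= len(buffers[key]):
        return None
    chunk = buffers[key][positions[key]]
    positions[key] += 1  # popping advances a cursor; the buffer keeps the chunk
    return chunk


key = ("llm-1", "text")
assert pop(key) == "Hello"
closed.add(key)              # an is_final chunk closes the stream to appends...
assert pop(key) == " world"  # ...but already-buffered chunks remain readable
assert pop(key) is None
```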
- - Args: - selector: List of strings identifying the stream location - - Returns: - True if the stream is closed, False otherwise - """ - key = tuple(selector) - return key in self._closed_streams - - def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None: - """Convert an in-memory session into its serializable form.""" - - if session is None: - return None - return ResponseSessionState(node_id=session.node_id, index=session.index) - - def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession: - """Rebuild a response session from serialized data.""" - - node = self._graph.nodes.get(session_state.node_id) - if node is None: - raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state") - - session = ResponseSession.from_node(node) - session.index = session_state.index - return session - - def dumps(self) -> str: - """Serialize coordinator state to JSON.""" - - with self._lock: - state = ResponseStreamCoordinatorState( - response_nodes=sorted(self._response_nodes), - active_session=self._serialize_session(self._active_session), - waiting_sessions=[ - session_state - for session in list(self._waiting_sessions) - if (session_state := self._serialize_session(session)) is not None - ], - pending_sessions=[ - session_state - for _, session in sorted(self._response_sessions.items()) - if (session_state := self._serialize_session(session)) is not None - ], - node_execution_ids=dict(sorted(self._node_execution_ids.items())), - paths_map={ - node_id: [path.edges.copy() for path in paths] - for node_id, paths in sorted(self._paths_maps.items()) - }, - stream_buffers=[ - StreamBufferState( - selector=selector, - events=[event.model_copy(deep=True) for event in events], - ) - for selector, events in sorted(self._stream_buffers.items()) - ], - stream_positions=[ - StreamPositionState(selector=selector, position=position) - for selector, position in sorted(self._stream_positions.items()) - ], - closed_streams=sorted(self._closed_streams), - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore coordinator state from JSON.""" - - state = ResponseStreamCoordinatorState.model_validate_json(data) - - if state.type != "ResponseStreamCoordinator": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - - with self._lock: - self._response_nodes = set(state.response_nodes) - self._paths_maps = { - node_id: [Path(edges=list(path_edges)) for path_edges in paths] - for node_id, paths in state.paths_map.items() - } - self._node_execution_ids = dict(state.node_execution_ids) - - self._stream_buffers = { - tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events] - for buffer in state.stream_buffers - } - self._stream_positions = { - tuple(position.selector): position.position for position in state.stream_positions - } - for selector in self._stream_buffers: - self._stream_positions.setdefault(selector, 0) - - self._closed_streams = {tuple(selector) for selector in state.closed_streams} - - self._waiting_sessions = deque( - self._session_from_state(session_state) for session_state in state.waiting_sessions - ) - self._response_sessions = { - session_state.node_id: self._session_from_state(session_state) - for session_state in state.pending_sessions - } - self._active_session = self._session_from_state(state.active_session) if state.active_session else None 
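The coordinator's stream methods above reduce to one pattern: an append-only chunk list per selector plus a separately tracked read cursor, both of which `dumps()` serializes so that `loads()` can resume mid-stream. A minimal standalone sketch of that pattern (illustrative names, not the coordinator's actual API):

```python
from collections.abc import Sequence


class StreamCursorBuffer:
    """Append-only chunk buffer with a per-selector read cursor (illustrative)."""

    def __init__(self) -> None:
        self._buffers: dict[tuple[str, ...], list[str]] = {}
        self._positions: dict[tuple[str, ...], int] = {}
        self._closed: set[tuple[str, ...]] = set()

    def append(self, selector: Sequence[str], chunk: str) -> None:
        key = tuple(selector)
        if key in self._closed:
            raise ValueError(f"Stream {'.'.join(selector)} is already closed")
        self._buffers.setdefault(key, []).append(chunk)
        self._positions.setdefault(key, 0)

    def pop(self, selector: Sequence[str]) -> str | None:
        key = tuple(selector)
        buffer = self._buffers.get(key, [])
        position = self._positions.get(key, 0)
        if position >= len(buffer):
            return None  # nothing unread; already-read chunks stay replayable
        self._positions[key] = position + 1
        return buffer[position]


buf = StreamCursorBuffer()
buf.append(["node_a", "text"], "Hello")
buf.append(["node_a", "text"], " world")
assert buf.pop(["node_a", "text"]) == "Hello"
# Serializing both the buffers and the positions (as dumps() does) lets a
# restored instance resume from " world" instead of re-emitting "Hello".
```

Keeping the read positions separate from the buffers, rather than consuming chunks destructively, is what makes already-emitted chunks replayable after a restore.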
diff --git a/api/graphon/graph_engine/response_coordinator/path.py b/api/graphon/graph_engine/response_coordinator/path.py deleted file mode 100644 index 50f2f4eb21..0000000000 --- a/api/graphon/graph_engine/response_coordinator/path.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Internal path representation for response coordinator. - -This module contains the private Path class used internally by ResponseStreamCoordinator -to track execution paths to response nodes. -""" - -from dataclasses import dataclass, field -from typing import TypeAlias - -EdgeID: TypeAlias = str - - -@dataclass -class Path: - """ - Represents a path of branch edges that must be taken to reach a response node. - - Note: This is an internal class not exposed in the public API. - """ - - edges: list[EdgeID] = field(default_factory=list[EdgeID]) - - def contains_edge(self, edge_id: EdgeID) -> bool: - """Check if this path contains the given edge.""" - return edge_id in self.edges - - def remove_edge(self, edge_id: EdgeID) -> None: - """Remove the given edge from this path in place.""" - if self.contains_edge(edge_id): - self.edges.remove(edge_id) - - def is_empty(self) -> bool: - """Check if the path has no edges (node is reachable).""" - return len(self.edges) == 0 diff --git a/api/graphon/graph_engine/response_coordinator/session.py b/api/graphon/graph_engine/response_coordinator/session.py deleted file mode 100644 index cb877f1504..0000000000 --- a/api/graphon/graph_engine/response_coordinator/session.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Internal response session management for response coordinator. - -This module contains the private ResponseSession class used internally -by ResponseStreamCoordinator to manage streaming sessions. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Protocol, cast - -from graphon.nodes.base.template import Template -from graphon.runtime.graph_runtime_state import NodeProtocol - - -class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): - """Structural contract required from nodes that can open a response session.""" - - def get_streaming_template(self) -> Template: ... - - -@dataclass -class ResponseSession: - """ - Represents an active response streaming session. - - Note: This is an internal class not exposed in the public API. - """ - - node_id: str - template: Template # Template object from the response node - index: int = 0 # Current position in the template segments - - @classmethod - def from_node(cls, node: NodeProtocol) -> ResponseSession: - """ - Create a ResponseSession from a response-capable node. - - The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer. - At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which - graph nodes should be treated as response-capable before they reach this factory. - - Args: - node: Node from the materialized workflow graph. - - Returns: - ResponseSession configured with the node's streaming template - - Raises: - TypeError: If node does not implement the response-session streaming contract. 
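-
-        Example (hedged sketch; "answer_node" stands in for a response-capable node):
-            session = ResponseSession.from_node(answer_node)
-            session.is_complete()  # False while template segments remain unprocessed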
- """ - response_node = cast(_ResponseSessionNodeProtocol, node) - try: - template = response_node.get_streaming_template() - except AttributeError as exc: - raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc - - return cls( - node_id=node.id, - template=template, - ) - - def is_complete(self) -> bool: - """Check if all segments in the template have been processed.""" - return self.index >= len(self.template.segments) diff --git a/api/graphon/graph_engine/worker.py b/api/graphon/graph_engine/worker.py deleted file mode 100644 index a0844ee48e..0000000000 --- a/api/graphon/graph_engine/worker.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Worker - Thread implementation for queue-based node execution - -Workers pull node IDs from the ready_queue, execute nodes, and push events -to the event_queue for the dispatcher to process. -""" - -import queue -import threading -import time -from collections.abc import Sequence -from contextlib import AbstractContextManager -from datetime import UTC, datetime -from typing import TYPE_CHECKING, final - -from typing_extensions import override - -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node - -from .ready_queue import ReadyQueue - -if TYPE_CHECKING: - pass - - -@final -class Worker(threading.Thread): - """ - Worker thread that executes nodes from the ready queue. - - Workers continuously pull node IDs from the ready_queue, execute the - corresponding nodes, and push the resulting events to the event_queue - for the dispatcher to process. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: Sequence[GraphEngineLayer], - worker_id: int = 0, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - """ - Initialize worker thread. 
- - Args: - ready_queue: Ready queue containing node IDs ready for execution - event_queue: Queue for pushing execution events - graph: Graph containing nodes to execute - layers: Graph engine layers for node execution hooks - worker_id: Unique identifier for this worker - execution_context: Optional execution context for context preservation - """ - super().__init__(name=f"GraphWorker-{worker_id}", daemon=True) - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._worker_id = worker_id - self._execution_context = execution_context - self._stop_event = threading.Event() - self._layers = layers if layers is not None else [] - self._last_task_time = time.time() - self._current_node_started_at: datetime | None = None - - def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() - - @property - def is_idle(self) -> bool: - """Check if the worker is currently idle.""" - # Worker is idle if it hasn't processed a task recently (within 0.2 seconds) - return (time.time() - self._last_task_time) > 0.2 - - @property - def idle_duration(self) -> float: - """Get the duration in seconds since the worker last processed a task.""" - return time.time() - self._last_task_time - - @property - def worker_id(self) -> int: - """Get the worker's ID.""" - return self._worker_id - - @override - def run(self) -> None: - """ - Main worker loop. - - Continuously pulls node IDs from ready_queue, executes them, - and pushes events to event_queue until stopped. - """ - while not self._stop_event.is_set(): - # Try to get a node ID from the ready queue (with timeout) - try: - node_id = self._ready_queue.get(timeout=0.1) - except queue.Empty: - continue - - self._last_task_time = time.time() - node = self._graph.nodes[node_id] - try: - self._current_node_started_at = None - self._execute_node(node) - self._ready_queue.task_done() - except Exception as e: - self._event_queue.put( - self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at) - ) - finally: - self._current_node_started_at = None - - def _execute_node(self, node: Node) -> None: - """ - Execute a single node and handle its events. 
- - Args: - node: The node instance to execute - """ - node.ensure_execution_id() - - error: Exception | None = None - result_event: GraphNodeEventBase | None = None - - # Execute the node with preserved context if execution context is provided - if self._execution_context is not None: - with self._execution_context: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - else: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - - def _invoke_node_run_start_hooks(self, node: Node) -> None: - """Invoke on_node_run_start hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_start(node) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _invoke_node_run_end_hooks( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """Invoke on_node_run_end hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_end(node, error, result_event) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _build_fallback_failure_event( - self, node: Node, error: Exception, *, started_at: datetime | None = None - ) -> NodeRunFailedEvent: - """Build a failed event when worker-level execution aborts before a node emits its own result event.""" - failure_time = datetime.now(UTC).replace(tzinfo=None) - error_message = str(error) - return NodeRunFailedEvent( - id=node.execution_id, - node_id=node.id, - node_type=node.node_type, - in_iteration_id=None, - error=error_message, - start_at=started_at or failure_time, - finished_at=failure_time, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error_message, - error_type=type(error).__name__, - ), - ) diff --git a/api/graphon/graph_engine/worker_management/__init__.py b/api/graphon/graph_engine/worker_management/__init__.py deleted file mode 100644 index 03de1f6daa..0000000000 --- a/api/graphon/graph_engine/worker_management/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Worker management subsystem for graph engine. - -This package manages the worker pool, including creation, -scaling, and activity tracking. -""" - -from .worker_pool import WorkerPool - -__all__ = [ - "WorkerPool", -] diff --git a/api/graphon/graph_engine/worker_management/worker_pool.py b/api/graphon/graph_engine/worker_management/worker_pool.py deleted file mode 100644 index 85cdf1ca21..0000000000 --- a/api/graphon/graph_engine/worker_management/worker_pool.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -Simple worker pool that consolidates functionality. - -This is a simpler implementation that merges WorkerPool, ActivityTracker, -DynamicScaler, and WorkerFactory into a single class. 
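-
-Typical lifecycle (a hedged sketch; the periodic caller of check_and_scale is
-not shown in this diff):
-
-    pool = WorkerPool(ready_queue, event_queue, graph, layers, config)
-    pool.start()            # spawns an initial worker count scaled to graph size
-    pool.check_and_scale()  # scales up on queue depth, down on idle workers
-    pool.stop()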
-""" - -import logging -import queue -import threading -from contextlib import AbstractContextManager -from typing import final - -from graphon.graph import Graph -from graphon.graph_events import GraphNodeEventBase - -from ..config import GraphEngineConfig -from ..layers.base import GraphEngineLayer -from ..ready_queue import ReadyQueue -from ..worker import Worker - -logger = logging.getLogger(__name__) - - -@final -class WorkerPool: - """ - Simple worker pool with integrated management. - - This class consolidates all worker management functionality into - a single, simpler implementation without excessive abstraction. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: list[GraphEngineLayer], - config: GraphEngineConfig, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - """ - Initialize the simple worker pool. - - Args: - ready_queue: Ready queue for nodes ready for execution - event_queue: Queue for worker events - graph: The workflow graph - layers: Graph engine layers for node execution hooks - config: GraphEngine worker pool configuration - execution_context: Optional execution context for context preservation - """ - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._execution_context = execution_context - self._layers = layers - self._config = config - - # Worker management - self._workers: list[Worker] = [] - self._worker_counter = 0 - self._lock = threading.RLock() - self._running = False - - # No longer tracking worker states with callbacks to avoid lock contention - - def start(self, initial_count: int | None = None) -> None: - """ - Start the worker pool. - - Args: - initial_count: Number of workers to start with (auto-calculated if None) - """ - with self._lock: - if self._running: - return - - self._running = True - - # Calculate initial worker count - if initial_count is None: - node_count = len(self._graph.nodes) - if node_count < 10: - initial_count = self._config.min_workers - elif node_count < 50: - initial_count = min(self._config.min_workers + 1, self._config.max_workers) - else: - initial_count = min(self._config.min_workers + 2, self._config.max_workers) - - logger.debug( - "Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)", - initial_count, - node_count, - self._config.min_workers, - self._config.max_workers, - ) - - # Create initial workers - for _ in range(initial_count): - self._create_worker() - - def stop(self) -> None: - """Stop all workers in the pool.""" - with self._lock: - self._running = False - worker_count = len(self._workers) - - if worker_count > 0: - logger.debug("Stopping worker pool: %d workers", worker_count) - - # Stop all workers - for worker in self._workers: - worker.stop() - - # Wait for workers to finish - for worker in self._workers: - if worker.is_alive(): - worker.join(timeout=2.0) - - self._workers.clear() - - def _create_worker(self) -> None: - """Create and start a new worker.""" - worker_id = self._worker_counter - self._worker_counter += 1 - - worker = Worker( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - worker_id=worker_id, - execution_context=self._execution_context, - ) - - worker.start() - self._workers.append(worker) - - def _remove_worker(self, worker: Worker, worker_id: int) -> None: - """Remove a specific worker from the pool.""" - # Stop the worker - worker.stop() - - # Wait for it to 
finish - if worker.is_alive(): - worker.join(timeout=2.0) - - # Remove from list - if worker in self._workers: - self._workers.remove(worker) - - def _try_scale_up(self, queue_depth: int, current_count: int) -> bool: - """ - Try to scale up workers if needed. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - - Returns: - True if scaled up, False otherwise - """ - if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers: - old_count = current_count - self._create_worker() - - logger.debug( - "Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)", - old_count, - len(self._workers), - queue_depth, - self._config.scale_up_threshold, - ) - return True - return False - - def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool: - """ - Try to scale down workers if we have excess capacity. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - active_count: Number of active workers - idle_count: Number of idle workers - - Returns: - True if scaled down, False otherwise - """ - # Skip if we're at minimum or have no idle workers - if current_count <= self._config.min_workers or idle_count == 0: - return False - - # Check if we have excess capacity - has_excess_capacity = ( - queue_depth <= active_count # Active workers can handle current queue - or idle_count > active_count # More idle than active workers - or (queue_depth == 0 and idle_count > 0) # No work and have idle workers - ) - - if not has_excess_capacity: - return False - - # Find and remove idle workers that have been idle long enough - workers_to_remove: list[tuple[Worker, int]] = [] - - for worker in self._workers: - # Check if worker is idle and has exceeded idle time threshold - if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time: - # Don't remove if it would leave us unable to handle the queue - remaining_workers = current_count - len(workers_to_remove) - 1 - if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2): - workers_to_remove.append((worker, worker.worker_id)) - # Only remove one worker per check to avoid aggressive scaling - break - - # Remove idle workers if any found - if workers_to_remove: - old_count = current_count - for worker, worker_id in workers_to_remove: - self._remove_worker(worker, worker_id) - - logger.debug( - "Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, " - "queue_depth=%d, active=%d, idle=%d)", - old_count, - len(self._workers), - len(workers_to_remove), - self._config.scale_down_idle_time, - queue_depth, - active_count, - idle_count - len(workers_to_remove), - ) - return True - - return False - - def check_and_scale(self) -> None: - """Check and perform scaling if needed.""" - with self._lock: - if not self._running: - return - - current_count = len(self._workers) - queue_depth = self._ready_queue.qsize() - - # Count active vs idle workers by querying their state directly - idle_count = sum(1 for worker in self._workers if worker.is_idle) - active_count = current_count - idle_count - - # Try to scale up if queue is backing up - self._try_scale_up(queue_depth, current_count) - - # Try to scale down if we have excess capacity - self._try_scale_down(queue_depth, current_count, active_count, idle_count) - - def get_worker_count(self) -> int: - """Get current number of workers.""" - with self._lock: - return len(self._workers) - - def 
get_status(self) -> dict[str, int]: - """ - Get pool status information. - - Returns: - Dictionary with status information - """ - with self._lock: - return { - "total_workers": len(self._workers), - "queue_depth": self._ready_queue.qsize(), - "min_workers": self._config.min_workers, - "max_workers": self._config.max_workers, - } diff --git a/api/graphon/graph_events/__init__.py b/api/graphon/graph_events/__init__.py deleted file mode 100644 index 7cec587a05..0000000000 --- a/api/graphon/graph_events/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -# Agent events -from .agent import NodeRunAgentLogEvent - -# Base events -from .base import ( - BaseGraphEvent, - GraphEngineEvent, - GraphNodeEventBase, -) - -# Graph events -from .graph import ( - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Iteration events -from .iteration import ( - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, -) - -# Loop events -from .loop import ( - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, -) - -# Node events -from .node import ( - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, - is_node_result_event, -) - -__all__ = [ - "BaseGraphEvent", - "GraphEngineEvent", - "GraphNodeEventBase", - "GraphRunAbortedEvent", - "GraphRunFailedEvent", - "GraphRunPartialSucceededEvent", - "GraphRunPausedEvent", - "GraphRunStartedEvent", - "GraphRunSucceededEvent", - "NodeRunAgentLogEvent", - "NodeRunExceptionEvent", - "NodeRunFailedEvent", - "NodeRunHumanInputFormFilledEvent", - "NodeRunHumanInputFormTimeoutEvent", - "NodeRunIterationFailedEvent", - "NodeRunIterationNextEvent", - "NodeRunIterationStartedEvent", - "NodeRunIterationSucceededEvent", - "NodeRunLoopFailedEvent", - "NodeRunLoopNextEvent", - "NodeRunLoopStartedEvent", - "NodeRunLoopSucceededEvent", - "NodeRunPauseRequestedEvent", - "NodeRunRetrieverResourceEvent", - "NodeRunRetryEvent", - "NodeRunStartedEvent", - "NodeRunStreamChunkEvent", - "NodeRunSucceededEvent", - "NodeRunVariableUpdatedEvent", - "is_node_result_event", -] diff --git a/api/graphon/graph_events/agent.py b/api/graphon/graph_events/agent.py deleted file mode 100644 index 759fe3a71c..0000000000 --- a/api/graphon/graph_events/agent.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import GraphAgentNodeEventBase - - -class NodeRunAgentLogEvent(GraphAgentNodeEventBase): - message_id: str = Field(..., description="message id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/graphon/graph_events/base.py b/api/graphon/graph_events/base.py deleted file mode 100644 index 4ea9787b9a..0000000000 --- a/api/graphon/graph_events/base.py +++ /dev/null @@ -1,31 +0,0 @@ 
-from pydantic import BaseModel, Field - -from graphon.enums import NodeType -from graphon.node_events import NodeRunResult - - -class GraphEngineEvent(BaseModel): - pass - - -class BaseGraphEvent(GraphEngineEvent): - pass - - -class GraphNodeEventBase(GraphEngineEvent): - id: str = Field(..., description="node execution id") - node_id: str - node_type: NodeType - - in_iteration_id: str | None = None - """iteration id if node is in iteration""" - in_loop_id: str | None = None - """loop id if node is in loop""" - - # The version of the node, or "1" if not specified. - node_version: str = "1" - node_run_result: NodeRunResult = Field(default_factory=NodeRunResult) - - -class GraphAgentNodeEventBase(GraphNodeEventBase): - pass diff --git a/api/graphon/graph_events/graph.py b/api/graphon/graph_events/graph.py deleted file mode 100644 index 3782cb49bc..0000000000 --- a/api/graphon/graph_events/graph.py +++ /dev/null @@ -1,57 +0,0 @@ -from pydantic import Field - -from graphon.entities.pause_reason import PauseReason -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph_events import BaseGraphEvent - - -class GraphRunStartedEvent(BaseGraphEvent): - # Reason is emitted for workflow start events and is always set. - reason: WorkflowStartReason = Field( - default=WorkflowStartReason.INITIAL, - description="reason for workflow start", - ) - - -class GraphRunSucceededEvent(BaseGraphEvent): - """Event emitted when a run completes successfully with final outputs.""" - - outputs: dict[str, object] = Field( - default_factory=dict, - description="Final workflow outputs keyed by output selector.", - ) - - -class GraphRunFailedEvent(BaseGraphEvent): - error: str = Field(..., description="failed reason") - exceptions_count: int = Field(description="exception count", default=0) - - -class GraphRunPartialSucceededEvent(BaseGraphEvent): - """Event emitted when a run finishes with partial success and failures.""" - - exceptions_count: int = Field(..., description="exception count") - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs that were materialised before failures occurred.", - ) - - -class GraphRunAbortedEvent(BaseGraphEvent): - """Event emitted when a graph run is aborted by user command.""" - - reason: str | None = Field(default=None, description="reason for abort") - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs produced before the abort was requested.", - ) - - -class GraphRunPausedEvent(BaseGraphEvent): - """Event emitted when a graph run is paused by user command.""" - - reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list) - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs available to the client while the run is paused.", - ) diff --git a/api/graphon/graph_events/human_input.py b/api/graphon/graph_events/human_input.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/graph_events/iteration.py b/api/graphon/graph_events/iteration.py deleted file mode 100644 index 28627395fd..0000000000 --- a/api/graphon/graph_events/iteration.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import GraphNodeEventBase - - -class NodeRunIterationStartedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = 
Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class NodeRunIterationNextEvent(GraphNodeEventBase): - node_title: str - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class NodeRunIterationSucceededEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class NodeRunIterationFailedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/graph_events/loop.py b/api/graphon/graph_events/loop.py deleted file mode 100644 index 7cdc5427e2..0000000000 --- a/api/graphon/graph_events/loop.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import GraphNodeEventBase - - -class NodeRunLoopStartedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class NodeRunLoopNextEvent(GraphNodeEventBase): - node_title: str - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class NodeRunLoopSucceededEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class NodeRunLoopFailedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/graph_events/node.py b/api/graphon/graph_events/node.py deleted file mode 100644 index 471ae08ee7..0000000000 --- a/api/graphon/graph_events/node.py +++ /dev/null @@ -1,106 +0,0 @@ -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from pydantic import Field - -from graphon.entities.pause_reason import PauseReason -from graphon.variables.variables import Variable - -from .base import GraphNodeEventBase - - -class NodeRunStartedEvent(GraphNodeEventBase): - node_title: str - predecessor_node_id: str | None = None - start_at: datetime = Field(..., description="node start time") - extras: dict[str, object] = Field(default_factory=dict) - - # FIXME(-LAN-): only for ToolNode - provider_type: str = "" - provider_id: str = "" - - -class NodeRunStreamChunkEvent(GraphNodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location 
(e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = Field(default=False, description="indicates if this is the last chunk") - - -class NodeRunRetrieverResourceEvent(GraphNodeEventBase): - retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - - -class NodeRunSucceededEvent(GraphNodeEventBase): - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunVariableUpdatedEvent(GraphNodeEventBase): - """Request that the engine apply a variable update before downstream observers continue.""" - - variable: Variable = Field(..., description="Updated variable payload to apply.") - - -class NodeRunFailedEvent(GraphNodeEventBase): - error: str = Field(..., description="error") - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunExceptionEvent(GraphNodeEventBase): - error: str = Field(..., description="error") - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunRetryEvent(NodeRunStartedEvent): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="which retry attempt is about to be performed") - - -class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase): - """Emitted when a HumanInput form is submitted and before the node finishes.""" - - node_title: str = Field(..., description="HumanInput node title") - rendered_content: str = Field(..., description="Markdown content rendered with user inputs.") - action_id: str = Field(..., description="User action identifier chosen in the form.") - action_text: str = Field(..., description="Display text of the chosen action button.") - - -class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): - """Emitted when a HumanInput form times out.""" - - node_title: str = Field(..., description="HumanInput node title") - expiration_time: datetime = Field(..., description="Form expiration time") - - -class NodeRunPauseRequestedEvent(GraphNodeEventBase): - reason: PauseReason = Field(..., description="pause reason") - - -def is_node_result_event(event: GraphNodeEventBase) -> bool: - """ - Check if an event is a final result event from node execution. - - A result event indicates the completion of a node execution and contains - runtime information such as inputs, outputs, or error details. - - Args: - event: The event to check - - Returns: - True if the event is a node result event (succeeded/failed/paused), False otherwise - """ - return isinstance( - event, - ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunPauseRequestedEvent, - ), - ) diff --git a/api/graphon/model_runtime/README.md b/api/graphon/model_runtime/README.md deleted file mode 100644 index b9d2c55210..0000000000 --- a/api/graphon/model_runtime/README.md +++ /dev/null @@ -1,51 +0,0 @@ -# Model Runtime - -This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers. 
-
-- On one hand, it decouples models from upstream and downstream processes, so developers can extend support to new models horizontally;
-- on the other hand, providers and models only need to be defined in the backend to be displayed directly in the frontend interface, with no frontend changes required.
-
-## Features
-
-- Supports invoking 6 model types:
-
-  - `LLM` - text completion, chat, and pre-computed token counting
-  - `Text Embedding Model` - text embedding and pre-computed token counting
-  - `Rerank Model` - segment reranking
-  - `Speech-to-text Model` - speech-to-text conversion
-  - `Text-to-speech Model` - text-to-speech conversion
-  - `Moderation` - content moderation
-
-- Model provider display
-
-  Displays all supported providers, including each provider's name, icon, supported model types, predefined models, configuration method, and credential form rules.
-
-- Selectable model list display
-
-  Once provider/model credentials are configured, the dropdowns (application orchestration interface/default model) show the available LLMs. Greyed-out entries are the predefined models of providers whose credentials have not been configured, so users can still see which models are supported.
-
-  The list also returns each LLM's configurable parameters and their rules. Because these are defined in the backend, different models can expose different parameter sets.
-
-- Provider/model credential authentication
-
-  The provider list returns the configuration for each credential form, and credentials can be validated through the Runtime's interface.
-
-## Structure
-
-Model Runtime is divided into three layers:
-
-- The outermost layer is the factory
-
-  It provides methods for listing all providers and all models, obtaining provider instances, and validating provider/model credentials.
-
-- The second layer is the provider layer
-
-  It exposes the current provider's model list, model instance lookup, provider credential validation, and provider configuration rules, and **can be extended horizontally** to support new providers.
-
-- The bottom layer is the model layer
-
-  It offers direct invocation of each model type, predefined model configuration, predefined/remote model listing, and model credential validation. Individual models also provide special methods, such as the LLM's pre-computed token counting and cost lookup, and this layer **can be extended horizontally** to new models under the same provider (within the supported model types).
-
-## Documentation
-
-For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
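A hedged sketch of the three-layer flow this README describes: the factory resolves a provider, the provider validates credentials and hands out model instances, and the model executes the call. Every name below (`Provider`, `Model`, `invoke_llm`) is illustrative; the actual factory API is not part of this diff.

```python
from typing import Protocol


class Model(Protocol):
    def invoke(self, prompt: str, model_parameters: dict[str, object]) -> str: ...


class Provider(Protocol):
    def validate_credentials(self, credentials: dict[str, str]) -> None: ...
    def get_model(self, model_type: str, model: str) -> Model: ...


def invoke_llm(
    providers: dict[str, Provider],
    provider_name: str,
    model_name: str,
    credentials: dict[str, str],
    prompt: str,
) -> str:
    # Factory layer: resolve the provider instance by name.
    provider = providers[provider_name]
    # Provider layer: credentials are validated once, at the provider level.
    provider.validate_credentials(credentials)
    # Model layer: fetch the concrete model and invoke it with model parameters.
    model = provider.get_model("llm", model_name)
    return model.invoke(prompt, {"temperature": 0.0})
```

Structural typing keeps the sketch minimal: each layer only needs the methods the README attributes to it.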
diff --git a/api/graphon/model_runtime/README_CN.md b/api/graphon/model_runtime/README_CN.md deleted file mode 100644 index 0a8b56b3fe..0000000000 --- a/api/graphon/model_runtime/README_CN.md +++ /dev/null @@ -1,64 +0,0 @@ -# Model Runtime - -该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。 - -- 一方面将模型和上下游解耦,方便开发者对模型横向扩展, -- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。 - -## 功能介绍 - -- 支持 6 种模型类型的能力调用 - - - `LLM` - LLM 文本补全、对话,预计算 tokens 能力 - - `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力 - - `Rerank Model` - 分段 Rerank 能力 - - `Speech-to-text Model` - 语音转文本能力 - - `Text-to-speech Model` - 文本转语音能力 - - `Moderation` - Moderation 能力 - -- 模型供应商展示 - - 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - -- 可选择的模型列表展示 - - 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - - 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - -- 供应商/模型凭据鉴权 - - 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 - -## 结构 - -Model Runtime 分三层: - -- 最外层为工厂方法 - - 提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。 - -- 第二层为供应商层 - - 提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。 - - 对于供应商/模型凭据,有两种情况 - - - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - - 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 - -- 最底层为模型层 - - 提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。 - - 在这里我们需要先区分模型参数与模型凭据。 - - - 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。 - - - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 - -## 文档 - -有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/api/graphon/model_runtime/callbacks/base_callback.py b/api/graphon/model_runtime/callbacks/base_callback.py deleted file mode 100644 index cd85cf6301..0000000000 --- a/api/graphon/model_runtime/callbacks/base_callback.py +++ /dev/null @@ -1,159 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence - -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -_TEXT_COLOR_MAPPING = { - "blue": "36;1", - "yellow": "33;1", - "pink": "38;5;200", - "green": "32;1", - "red": "31;1", -} - - -class Callback(ABC): - """ - Base class for callbacks. - Only for LLM. 
- """ - - raise_error: bool = False - - @abstractmethod - def on_before_invoke( - self, - llm_instance: AIModel, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - raise NotImplementedError() - - @abstractmethod - def on_new_chunk( - self, - llm_instance: AIModel, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - raise NotImplementedError() - - @abstractmethod - def on_after_invoke( - self, - llm_instance: AIModel, - result: LLMResult, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - raise NotImplementedError() - - @abstractmethod - def on_invoke_error( - self, - llm_instance: AIModel, - ex: Exception, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the 
invocation - :param invocation_context: opaque request metadata for the current invocation - """ - raise NotImplementedError() - - def print_text(self, text: str, color: str | None = None, end: str = ""): - """Print text with highlighting and no end characters.""" - text_to_print = self._get_colored_text(text, color) if color else text - print(text_to_print, end=end) - - def _get_colored_text(self, text: str, color: str) -> str: - """Get colored text.""" - color_str = _TEXT_COLOR_MAPPING[color] - return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" diff --git a/api/graphon/model_runtime/callbacks/logging_callback.py b/api/graphon/model_runtime/callbacks/logging_callback.py deleted file mode 100644 index f96eb446fc..0000000000 --- a/api/graphon/model_runtime/callbacks/logging_callback.py +++ /dev/null @@ -1,180 +0,0 @@ -import json -import logging -import sys -from collections.abc import Mapping, Sequence -from typing import cast - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class LoggingCallback(Callback): - def on_before_invoke( - self, - llm_instance: AIModel, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: optional end-user identifier for the invocation - :param invocation_context: opaque request metadata for the current invocation - """ - self.print_text("\n[on_llm_before_invoke]\n", color="blue") - self.print_text(f"Model: {model}\n", color="blue") - self.print_text("Parameters:\n", color="blue") - for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color="blue") - - if stop: - self.print_text(f"\tstop: {stop}\n", color="blue") - - if tools: - self.print_text("\tTools:\n", color="blue") - for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color="blue") - - self.print_text(f"Stream: {stream}\n", color="blue") - if user: - self.print_text(f"User: {user}\n", color="blue") - - if invocation_context: - self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue") - - self.print_text("Prompt messages:\n", color="blue") - for prompt_message in prompt_messages: - if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - - self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") - self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") - - if stream: - self.print_text("\n[on_llm_new_chunk]") - - def on_new_chunk( - self, - llm_instance: AIModel, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = 
True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _ = user, invocation_context - sys.stdout.write(cast(str, chunk.delta.message.content)) - sys.stdout.flush() - - def on_after_invoke( - self, - llm_instance: AIModel, - result: LLMResult, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _ = user, invocation_context - self.print_text("\n[on_llm_after_invoke]\n", color="yellow") - self.print_text(f"Content: {result.message.content}\n", color="yellow") - - if result.message.tool_calls: - self.print_text("Tool calls:\n", color="yellow") - for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color="yellow") - self.print_text(f"\t{tool_call.function.name}\n", color="yellow") - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") - - self.print_text(f"Model: {result.model}\n", color="yellow") - self.print_text(f"Usage: {result.usage}\n", color="yellow") - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") - - def on_invoke_error( - self, - llm_instance: AIModel, - ex: Exception, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - invocation_context: Mapping[str, object] | None = None, - ): - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _ = user, invocation_context - self.print_text("\n[on_llm_invoke_error]\n", color="red") - logger.exception(ex) diff --git a/api/graphon/model_runtime/entities/__init__.py b/api/graphon/model_runtime/entities/__init__.py deleted file mode 100644 index a24e437d48..0000000000 --- a/api/graphon/model_runtime/entities/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from .message_entities import ( - AssistantPromptMessage, - AudioPromptMessageContent, - DocumentPromptMessageContent, - 
ImagePromptMessageContent, - MultiModalPromptMessageContent, - PromptMessage, - PromptMessageContent, - PromptMessageContentType, - PromptMessageRole, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, - VideoPromptMessageContent, -) -from .model_entities import ModelPropertyKey - -__all__ = [ - "AssistantPromptMessage", - "AudioPromptMessageContent", - "DocumentPromptMessageContent", - "ImagePromptMessageContent", - "LLMMode", - "LLMResult", - "LLMResultChunk", - "LLMResultChunkDelta", - "LLMUsage", - "ModelPropertyKey", - "MultiModalPromptMessageContent", - "PromptMessage", - "PromptMessageContent", - "PromptMessageContentType", - "PromptMessageRole", - "PromptMessageTool", - "SystemPromptMessage", - "TextPromptMessageContent", - "ToolPromptMessage", - "UserPromptMessage", - "VideoPromptMessageContent", -] diff --git a/api/graphon/model_runtime/entities/common_entities.py b/api/graphon/model_runtime/entities/common_entities.py deleted file mode 100644 index b673efae22..0000000000 --- a/api/graphon/model_runtime/entities/common_entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from pydantic import BaseModel, model_validator - - -class I18nObject(BaseModel): - """ - Model class for i18n object. - """ - - zh_Hans: str | None = None - en_US: str - - @model_validator(mode="after") - def _(self): - if not self.zh_Hans: - self.zh_Hans = self.en_US - return self diff --git a/api/graphon/model_runtime/entities/defaults.py b/api/graphon/model_runtime/entities/defaults.py deleted file mode 100644 index bcce17c5d5..0000000000 --- a/api/graphon/model_runtime/entities/defaults.py +++ /dev/null @@ -1,130 +0,0 @@ -from graphon.model_runtime.entities.model_entities import DefaultParameterName - -PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { - DefaultParameterName.TEMPERATURE: { - "label": { - "en_US": "Temperature", - "zh_Hans": "温度", - }, - "type": "float", - "help": { - "en_US": "Controls randomness. Lower temperature results in less random completions." - " As the temperature approaches zero, the model will become deterministic and repetitive." 
- " Higher temperature results in more random completions.", - "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。" - "较高的温度会导致更多的随机完成。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.TOP_P: { - "label": { - "en_US": "Top P", - "zh_Hans": "Top P", - }, - "type": "float", - "help": { - "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options" - " are considered.", - "zh_Hans": "通过核心采样控制多样性:0.5 表示考虑了一半的所有可能性加权选项。", - }, - "required": False, - "default": 1.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.TOP_K: { - "label": { - "en_US": "Top K", - "zh_Hans": "Top K", - }, - "type": "int", - "help": { - "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", - "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", - }, - "required": False, - "default": 50, - "min": 1, - "max": 100, - "precision": 0, - }, - DefaultParameterName.PRESENCE_PENALTY: { - "label": { - "en_US": "Presence Penalty", - "zh_Hans": "存在惩罚", - }, - "type": "float", - "help": { - "en_US": "Applies a penalty to the log-probability of tokens already in the text.", - "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.FREQUENCY_PENALTY: { - "label": { - "en_US": "Frequency Penalty", - "zh_Hans": "频率惩罚", - }, - "type": "float", - "help": { - "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", - "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.MAX_TOKENS: { - "label": { - "en_US": "Max Tokens", - "zh_Hans": "最大 Token 数", - }, - "type": "int", - "help": { - "en_US": "Specifies the upper limit on the length of generated results." 
- " If the generated results are truncated, you can increase this parameter.", - "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", - }, - "required": False, - "default": 64, - "min": 1, - "max": 2048, - "precision": 0, - }, - DefaultParameterName.RESPONSE_FORMAT: { - "label": { - "en_US": "Response Format", - "zh_Hans": "回复格式", - }, - "type": "string", - "help": { - "en_US": "Set a response format, ensure the output from llm is a valid code block as possible," - " such as JSON, XML, etc.", - "zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等", - }, - "required": False, - "options": ["JSON", "XML"], - }, - DefaultParameterName.JSON_SCHEMA: { - "label": { - "en_US": "JSON Schema", - }, - "type": "text", - "help": { - "en_US": "Set a response json schema will ensure LLM to adhere it.", - "zh_Hans": "设置返回的 json schema,llm 将按照它返回", - }, - "required": False, - }, -} diff --git a/api/graphon/model_runtime/entities/llm_entities.py b/api/graphon/model_runtime/entities/llm_entities.py deleted file mode 100644 index bfc80f21c5..0000000000 --- a/api/graphon/model_runtime/entities/llm_entities.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from decimal import Decimal -from enum import StrEnum -from typing import Any, TypedDict, Union - -from pydantic import BaseModel, Field - -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from graphon.model_runtime.entities.model_entities import ModelUsage, PriceInfo - - -class LLMMode(StrEnum): - """ - Enum class for large language model mode. - """ - - COMPLETION = "completion" - CHAT = "chat" - - -class LLMUsageMetadata(TypedDict, total=False): - """ - TypedDict for LLM usage metadata. - All fields are optional. - """ - - prompt_tokens: int - completion_tokens: int - total_tokens: int - prompt_unit_price: Union[float, str] - completion_unit_price: Union[float, str] - total_price: Union[float, str] - currency: str - prompt_price_unit: Union[float, str] - completion_price_unit: Union[float, str] - prompt_price: Union[float, str] - completion_price: Union[float, str] - latency: float - time_to_first_token: float - time_to_generate: float - - -class LLMUsage(ModelUsage): - """ - Model class for llm usage. - """ - - prompt_tokens: int - prompt_unit_price: Decimal - prompt_price_unit: Decimal - prompt_price: Decimal - completion_tokens: int - completion_unit_price: Decimal - completion_price_unit: Decimal - completion_price: Decimal - total_tokens: int - total_price: Decimal - currency: str - latency: float - time_to_first_token: float | None = None - time_to_generate: float | None = None - - @classmethod - def empty_usage(cls): - return cls( - prompt_tokens=0, - prompt_unit_price=Decimal("0.0"), - prompt_price_unit=Decimal("0.0"), - prompt_price=Decimal("0.0"), - completion_tokens=0, - completion_unit_price=Decimal("0.0"), - completion_price_unit=Decimal("0.0"), - completion_price=Decimal("0.0"), - total_tokens=0, - total_price=Decimal("0.0"), - currency="USD", - latency=0.0, - time_to_first_token=None, - time_to_generate=None, - ) - - @classmethod - def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage: - """ - Create LLMUsage instance from metadata dictionary with default values. 
- - Args: - metadata: TypedDict containing usage metadata - - Returns: - LLMUsage instance with values from metadata or defaults - """ - prompt_tokens = metadata.get("prompt_tokens", 0) - completion_tokens = metadata.get("completion_tokens", 0) - total_tokens = metadata.get("total_tokens", 0) - - # If total_tokens is not provided but prompt and completion tokens are, - # calculate total_tokens - if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0): - total_tokens = prompt_tokens + completion_tokens - - return cls( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), - completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))), - total_price=Decimal(str(metadata.get("total_price", 0))), - currency=metadata.get("currency", "USD"), - prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))), - completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))), - prompt_price=Decimal(str(metadata.get("prompt_price", 0))), - completion_price=Decimal(str(metadata.get("completion_price", 0))), - latency=metadata.get("latency", 0.0), - time_to_first_token=metadata.get("time_to_first_token"), - time_to_generate=metadata.get("time_to_generate"), - ) - - def plus(self, other: LLMUsage) -> LLMUsage: - """ - Add two LLMUsage instances together. - - :param other: Another LLMUsage instance to add - :return: A new LLMUsage instance with summed values - """ - if self.total_tokens == 0: - return other - else: - return LLMUsage( - prompt_tokens=self.prompt_tokens + other.prompt_tokens, - prompt_unit_price=other.prompt_unit_price, - prompt_price_unit=other.prompt_price_unit, - prompt_price=self.prompt_price + other.prompt_price, - completion_tokens=self.completion_tokens + other.completion_tokens, - completion_unit_price=other.completion_unit_price, - completion_price_unit=other.completion_price_unit, - completion_price=self.completion_price + other.completion_price, - total_tokens=self.total_tokens + other.total_tokens, - total_price=self.total_price + other.total_price, - currency=other.currency, - latency=self.latency + other.latency, - time_to_first_token=other.time_to_first_token, - time_to_generate=other.time_to_generate, - ) - - def __add__(self, other: LLMUsage) -> LLMUsage: - """ - Overload the + operator to add two LLMUsage instances. - - :param other: Another LLMUsage instance to add - :return: A new LLMUsage instance with summed values - """ - return self.plus(other) - - -class LLMResult(BaseModel): - """ - Model class for llm result. - """ - - id: str | None = None - model: str - prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - message: AssistantPromptMessage - usage: LLMUsage - system_fingerprint: str | None = None - reasoning_content: str | None = None - - -class LLMStructuredOutput(BaseModel): - """ - Model class for llm structured output. - """ - - structured_output: Mapping[str, Any] | None = None - - -class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): - """ - Model class for llm result with structured output. - """ - - -class LLMResultChunkDelta(BaseModel): - """ - Model class for llm result chunk delta. - """ - - index: int - message: AssistantPromptMessage - usage: LLMUsage | None = None - finish_reason: str | None = None - - -class LLMResultChunk(BaseModel): - """ - Model class for llm result chunk. 
- """ - - model: str - prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - system_fingerprint: str | None = None - delta: LLMResultChunkDelta - - -class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput): - """ - Model class for llm result chunk with structured output. - """ - - -class NumTokensResult(PriceInfo): - """ - Model class for number of tokens result. - """ - - tokens: int diff --git a/api/graphon/model_runtime/entities/message_entities.py b/api/graphon/model_runtime/entities/message_entities.py deleted file mode 100644 index 402bfdc606..0000000000 --- a/api/graphon/model_runtime/entities/message_entities.py +++ /dev/null @@ -1,279 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from collections.abc import Mapping, Sequence -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, Union - -from pydantic import BaseModel, Field, field_serializer, field_validator - - -class PromptMessageRole(StrEnum): - """ - Enum class for prompt message. - """ - - SYSTEM = auto() - USER = auto() - ASSISTANT = auto() - TOOL = auto() - - @classmethod - def value_of(cls, value: str) -> PromptMessageRole: - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid prompt message type value {value}") - - -class PromptMessageTool(BaseModel): - """ - Model class for prompt message tool. - """ - - name: str - description: str - parameters: dict - - -class PromptMessageFunction(BaseModel): - """ - Model class for prompt message function. - """ - - type: str = "function" - function: PromptMessageTool - - -class PromptMessageContentType(StrEnum): - """ - Enum class for prompt message content type. - """ - - TEXT = auto() - IMAGE = auto() - AUDIO = auto() - VIDEO = auto() - DOCUMENT = auto() - - -class PromptMessageContent(ABC, BaseModel): - """ - Model class for prompt message content. - """ - - type: PromptMessageContentType - - -class TextPromptMessageContent(PromptMessageContent): - """ - Model class for text prompt message content. - """ - - type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore - data: str - - -class MultiModalPromptMessageContent(PromptMessageContent): - """ - Model class for multi-modal prompt message content. - """ - - format: str = Field(default=..., description="the format of multi-modal file") - base64_data: str = Field(default="", description="the base64 data of multi-modal file") - url: str = Field(default="", description="the url of multi-modal file") - mime_type: str = Field(default=..., description="the mime type of multi-modal file") - filename: str = Field(default="", description="the filename of multi-modal file") - - @property - def data(self): - return self.url or f"data:{self.mime_type};base64,{self.base64_data}" - - -class VideoPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore - - -class AudioPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore - - -class ImagePromptMessageContent(MultiModalPromptMessageContent): - """ - Model class for image prompt message content. 
- """ - - class DETAIL(StrEnum): - LOW = auto() - HIGH = auto() - - type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore - detail: DETAIL = DETAIL.LOW - - -class DocumentPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore - - -PromptMessageContentUnionTypes = Annotated[ - Union[ - TextPromptMessageContent, - ImagePromptMessageContent, - DocumentPromptMessageContent, - AudioPromptMessageContent, - VideoPromptMessageContent, - ], - Field(discriminator="type"), -] - - -CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = { - PromptMessageContentType.TEXT: TextPromptMessageContent, - PromptMessageContentType.IMAGE: ImagePromptMessageContent, - PromptMessageContentType.AUDIO: AudioPromptMessageContent, - PromptMessageContentType.VIDEO: VideoPromptMessageContent, - PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent, -} - - -class PromptMessage(ABC, BaseModel): - """ - Model class for prompt message. - """ - - role: PromptMessageRole - content: str | list[PromptMessageContentUnionTypes] | None = None - name: str | None = None - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return not self.content - - def get_text_content(self) -> str: - """ - Get text content from prompt message. - - :return: Text content as string, empty string if no text content - """ - if isinstance(self.content, str): - return self.content - elif isinstance(self.content, list): - text_parts = [] - for item in self.content: - if isinstance(item, TextPromptMessageContent): - text_parts.append(item.data) - return "".join(text_parts) - else: - return "" - - @field_validator("content", mode="before") - @classmethod - def validate_content(cls, v): - if isinstance(v, list): - prompts = [] - for prompt in v: - if isinstance(prompt, PromptMessageContent): - if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent): - prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) - elif isinstance(prompt, dict): - prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt) - else: - raise ValueError(f"invalid prompt message {prompt}") - prompts.append(prompt) - return prompts - return v - - @field_serializer("content") - def serialize_content( - self, content: Union[str, Sequence[PromptMessageContent]] | None - ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None: - if content is None or isinstance(content, str): - return content - if isinstance(content, list): - return [item.model_dump() if hasattr(item, "model_dump") else item for item in content] - return content - - -class UserPromptMessage(PromptMessage): - """ - Model class for user prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.USER - - -class AssistantPromptMessage(PromptMessage): - """ - Model class for assistant prompt message. - """ - - class ToolCall(BaseModel): - """ - Model class for assistant prompt message tool call. - """ - - class ToolCallFunction(BaseModel): - """ - Model class for assistant prompt message tool call function. 
- """ - - name: str - arguments: str - - id: str - type: str - function: ToolCallFunction - - @field_validator("id", mode="before") - @classmethod - def transform_id_to_str(cls, value) -> str: - if not isinstance(value, str): - return str(value) - else: - return value - - role: PromptMessageRole = PromptMessageRole.ASSISTANT - tool_calls: list[ToolCall] = [] - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return super().is_empty() and not self.tool_calls - - -class SystemPromptMessage(PromptMessage): - """ - Model class for system prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.SYSTEM - - -class ToolPromptMessage(PromptMessage): - """ - Model class for tool prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.TOOL - tool_call_id: str - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return super().is_empty() and not self.tool_call_id diff --git a/api/graphon/model_runtime/entities/model_entities.py b/api/graphon/model_runtime/entities/model_entities.py deleted file mode 100644 index 5ec4970faf..0000000000 --- a/api/graphon/model_runtime/entities/model_entities.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -from decimal import Decimal -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, ConfigDict, model_validator - -from graphon.model_runtime.entities.common_entities import I18nObject - - -class ModelType(StrEnum): - """ - Enum class for model type. - """ - - LLM = auto() - TEXT_EMBEDDING = "text-embedding" - RERANK = auto() - SPEECH2TEXT = auto() - MODERATION = auto() - TTS = auto() - - @classmethod - def value_of(cls, origin_model_type: str) -> ModelType: - """ - Get model type from origin model type. - - :return: model type - """ - if origin_model_type in {"text-generation", cls.LLM}: - return cls.LLM - elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}: - return cls.TEXT_EMBEDDING - elif origin_model_type in {"reranking", cls.RERANK}: - return cls.RERANK - elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}: - return cls.SPEECH2TEXT - elif origin_model_type in {"tts", cls.TTS}: - return cls.TTS - elif origin_model_type == cls.MODERATION: - return cls.MODERATION - else: - raise ValueError(f"invalid origin model type {origin_model_type}") - - def to_origin_model_type(self) -> str: - """ - Get origin model type from model type. - - :return: origin model type - """ - if self == self.LLM: - return "text-generation" - elif self == self.TEXT_EMBEDDING: - return "embeddings" - elif self == self.RERANK: - return "reranking" - elif self == self.SPEECH2TEXT: - return "speech2text" - elif self == self.TTS: - return "tts" - elif self == self.MODERATION: - return "moderation" - else: - raise ValueError(f"invalid model type {self}") - - -class FetchFrom(StrEnum): - """ - Enum class for fetch from. - """ - - PREDEFINED_MODEL = "predefined-model" - CUSTOMIZABLE_MODEL = "customizable-model" - - -class ModelFeature(StrEnum): - """ - Enum class for llm feature. 
- """ - - TOOL_CALL = "tool-call" - MULTI_TOOL_CALL = "multi-tool-call" - AGENT_THOUGHT = "agent-thought" - VISION = auto() - STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = auto() - VIDEO = auto() - AUDIO = auto() - STRUCTURED_OUTPUT = "structured-output" - - -class DefaultParameterName(StrEnum): - """ - Enum class for parameter template variable. - """ - - TEMPERATURE = auto() - TOP_P = auto() - TOP_K = auto() - PRESENCE_PENALTY = auto() - FREQUENCY_PENALTY = auto() - MAX_TOKENS = auto() - RESPONSE_FORMAT = auto() - JSON_SCHEMA = auto() - - @classmethod - def value_of(cls, value: Any) -> DefaultParameterName: - """ - Get parameter name from value. - - :param value: parameter value - :return: parameter name - """ - for name in cls: - if name.value == value: - return name - raise ValueError(f"invalid parameter name {value}") - - -class ParameterType(StrEnum): - """ - Enum class for parameter type. - """ - - FLOAT = auto() - INT = auto() - STRING = auto() - BOOLEAN = auto() - TEXT = auto() - - -class ModelPropertyKey(StrEnum): - """ - Enum class for model property key. - """ - - MODE = auto() - CONTEXT_SIZE = auto() - MAX_CHUNKS = auto() - FILE_UPLOAD_LIMIT = auto() - SUPPORTED_FILE_EXTENSIONS = auto() - MAX_CHARACTERS_PER_CHUNK = auto() - DEFAULT_VOICE = auto() - VOICES = auto() - WORD_LIMIT = auto() - AUDIO_TYPE = auto() - MAX_WORKERS = auto() - - -class ProviderModel(BaseModel): - """ - Model class for provider model. - """ - - model: str - label: I18nObject - model_type: ModelType - features: list[ModelFeature] | None = None - fetch_from: FetchFrom - model_properties: dict[ModelPropertyKey, Any] - deprecated: bool = False - model_config = ConfigDict(protected_namespaces=()) - - @property - def support_structure_output(self) -> bool: - return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features - - -class ParameterRule(BaseModel): - """ - Model class for parameter rule. - """ - - name: str - use_template: str | None = None - label: I18nObject - type: ParameterType - help: I18nObject | None = None - required: bool = False - default: Any | None = None - min: float | None = None - max: float | None = None - precision: int | None = None - options: list[str] = [] - - -class PriceConfig(BaseModel): - """ - Model class for pricing info. - """ - - input: Decimal - output: Decimal | None = None - unit: Decimal - currency: str - - -class AIModelEntity(ProviderModel): - """ - Model class for AI model. - """ - - parameter_rules: list[ParameterRule] = [] - pricing: PriceConfig | None = None - - @model_validator(mode="after") - def validate_model(self): - supported_schema_keys = ["json_schema"] - schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None) - if not schema_key: - return self - if self.features is None: - self.features = [ModelFeature.STRUCTURED_OUTPUT] - else: - if ModelFeature.STRUCTURED_OUTPUT not in self.features: - self.features.append(ModelFeature.STRUCTURED_OUTPUT) - return self - - -class ModelUsage(BaseModel): - pass - - -class PriceType(StrEnum): - """ - Enum class for price type. - """ - - INPUT = auto() - OUTPUT = auto() - - -class PriceInfo(BaseModel): - """ - Model class for price info. 
- """ - - unit_price: Decimal - unit: Decimal - total_amount: Decimal - currency: str diff --git a/api/graphon/model_runtime/entities/provider_entities.py b/api/graphon/model_runtime/entities/provider_entities.py deleted file mode 100644 index 8e6c516fb9..0000000000 --- a/api/graphon/model_runtime/entities/provider_entities.py +++ /dev/null @@ -1,179 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType - - -class ConfigurateMethod(StrEnum): - """ - Enum class for configurate method of provider model. - """ - - PREDEFINED_MODEL = "predefined-model" - CUSTOMIZABLE_MODEL = "customizable-model" - - -class FormType(StrEnum): - """ - Enum class for form type. - """ - - TEXT_INPUT = "text-input" - SECRET_INPUT = "secret-input" - SELECT = auto() - RADIO = auto() - SWITCH = auto() - - -class FormShowOnObject(BaseModel): - """ - Model class for form show on. - """ - - variable: str - value: str - - -class FormOption(BaseModel): - """ - Model class for form option. - """ - - label: I18nObject - value: str - show_on: list[FormShowOnObject] = [] - - @model_validator(mode="after") - def _(self): - if not self.label: - self.label = I18nObject(en_US=self.value) - return self - - -class CredentialFormSchema(BaseModel): - """ - Model class for credential form schema. - """ - - variable: str - label: I18nObject - type: FormType - required: bool = True - default: str | None = None - options: list[FormOption] | None = None - placeholder: I18nObject | None = None - max_length: int = 0 - show_on: list[FormShowOnObject] = [] - - -class ProviderCredentialSchema(BaseModel): - """ - Model class for provider credential schema. - """ - - credential_form_schemas: list[CredentialFormSchema] - - -class FieldModelSchema(BaseModel): - label: I18nObject - placeholder: I18nObject | None = None - - -class ModelCredentialSchema(BaseModel): - """ - Model class for model credential schema. - """ - - model: FieldModelSchema - credential_form_schemas: list[CredentialFormSchema] - - -class SimpleProviderEntity(BaseModel): - """ - Simplified provider schema exposed to callers. - - `provider` is the canonical runtime identifier. `provider_name` is an optional - compatibility alias for short-name lookups and is empty when no alias exists. - """ - - provider: str - provider_name: str = "" - label: I18nObject - icon_small: I18nObject | None = None - icon_small_dark: I18nObject | None = None - supported_model_types: Sequence[ModelType] - models: list[AIModelEntity] = [] - - -class ProviderHelpEntity(BaseModel): - """ - Model class for provider help. - """ - - title: I18nObject - url: I18nObject - - -class ProviderEntity(BaseModel): - """ - Runtime-native provider schema. - - `provider` is the canonical runtime identifier. `provider_name` is a - compatibility alias for callers that still resolve providers by short name and - is empty when no alias exists. 
- """ - - provider: str - provider_name: str = "" - label: I18nObject - description: I18nObject | None = None - icon_small: I18nObject | None = None - icon_small_dark: I18nObject | None = None - background: str | None = None - help: ProviderHelpEntity | None = None - supported_model_types: Sequence[ModelType] - configurate_methods: list[ConfigurateMethod] - models: list[AIModelEntity] = Field(default_factory=list) - provider_credential_schema: ProviderCredentialSchema | None = None - model_credential_schema: ModelCredentialSchema | None = None - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - # position from plugin _position.yaml - position: dict[str, list[str]] | None = {} - - @field_validator("models", mode="before") - @classmethod - def validate_models(cls, v): - # returns EmptyList if v is empty - if not v: - return [] - return v - - def to_simple_provider(self) -> SimpleProviderEntity: - """ - Convert to simple provider. - - :return: simple provider - """ - return SimpleProviderEntity( - provider=self.provider, - provider_name=self.provider_name, - label=self.label, - icon_small=self.icon_small, - supported_model_types=self.supported_model_types, - models=self.models, - ) - - -class ProviderConfig(BaseModel): - """ - Model class for provider config. - """ - - provider: str - credentials: dict diff --git a/api/graphon/model_runtime/entities/rerank_entities.py b/api/graphon/model_runtime/entities/rerank_entities.py deleted file mode 100644 index 8a0bb5fac2..0000000000 --- a/api/graphon/model_runtime/entities/rerank_entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import TypedDict - -from pydantic import BaseModel - - -class MultimodalRerankInput(TypedDict): - content: str - content_type: str - - -class RerankDocument(BaseModel): - """ - Model class for rerank document. - """ - - index: int - text: str - score: float - - -class RerankResult(BaseModel): - """ - Model class for rerank result. - """ - - model: str - docs: list[RerankDocument] diff --git a/api/graphon/model_runtime/entities/text_embedding_entities.py b/api/graphon/model_runtime/entities/text_embedding_entities.py deleted file mode 100644 index 08ffd83b5b..0000000000 --- a/api/graphon/model_runtime/entities/text_embedding_entities.py +++ /dev/null @@ -1,47 +0,0 @@ -from decimal import Decimal -from enum import StrEnum, auto - -from pydantic import BaseModel - -from graphon.model_runtime.entities.model_entities import ModelUsage - - -class EmbeddingInputType(StrEnum): - """Embedding request input variants understood by the model runtime.""" - - DOCUMENT = auto() - QUERY = auto() - - -class EmbeddingUsage(ModelUsage): - """ - Model class for embedding usage. - """ - - tokens: int - total_tokens: int - unit_price: Decimal - price_unit: Decimal - total_price: Decimal - currency: str - latency: float - - -class EmbeddingResult(BaseModel): - """ - Model class for text embedding result. - """ - - model: str - embeddings: list[list[float]] - usage: EmbeddingUsage - - -class FileEmbeddingResult(BaseModel): - """ - Model class for file embedding result. 
- """ - - model: str - embeddings: list[list[float]] - usage: EmbeddingUsage diff --git a/api/graphon/model_runtime/errors/__init__.py b/api/graphon/model_runtime/errors/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/errors/invoke.py b/api/graphon/model_runtime/errors/invoke.py deleted file mode 100644 index 1a57078b98..0000000000 --- a/api/graphon/model_runtime/errors/invoke.py +++ /dev/null @@ -1,41 +0,0 @@ -class InvokeError(ValueError): - """Base class for all LLM exceptions.""" - - description: str | None = None - - def __init__(self, description: str | None = None): - if description is not None: - self.description = description - - def __str__(self): - return self.description or self.__class__.__name__ - - -class InvokeConnectionError(InvokeError): - """Raised when the Invoke returns connection error.""" - - description = "Connection Error" - - -class InvokeServerUnavailableError(InvokeError): - """Raised when the Invoke returns server unavailable error.""" - - description = "Server Unavailable Error" - - -class InvokeRateLimitError(InvokeError): - """Raised when the Invoke returns rate limit error.""" - - description = "Rate Limit Error" - - -class InvokeAuthorizationError(InvokeError): - """Raised when the Invoke returns authorization error.""" - - description = "Incorrect model credentials provided, please check and try again. " - - -class InvokeBadRequestError(InvokeError): - """Raised when the Invoke returns bad request.""" - - description = "Bad Request Error" diff --git a/api/graphon/model_runtime/errors/validate.py b/api/graphon/model_runtime/errors/validate.py deleted file mode 100644 index 16bebcc67d..0000000000 --- a/api/graphon/model_runtime/errors/validate.py +++ /dev/null @@ -1,6 +0,0 @@ -class CredentialsValidateFailedError(ValueError): - """ - Credentials validate failed error - """ - - pass diff --git a/api/graphon/model_runtime/memory/__init__.py b/api/graphon/model_runtime/memory/__init__.py deleted file mode 100644 index 2d954486c3..0000000000 --- a/api/graphon/model_runtime/memory/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory - -__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"] diff --git a/api/graphon/model_runtime/memory/prompt_message_memory.py b/api/graphon/model_runtime/memory/prompt_message_memory.py deleted file mode 100644 index 03e26e9ff5..0000000000 --- a/api/graphon/model_runtime/memory/prompt_message_memory.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import Protocol - -from graphon.model_runtime.entities import PromptMessage - -DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000 - - -class PromptMessageMemory(Protocol): - """Port for loading memory as prompt messages.""" - - def get_history_prompt_messages( - self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None - ) -> Sequence[PromptMessage]: - """Return historical prompt messages constrained by token/message limits.""" - ... 
diff --git a/api/graphon/model_runtime/model_providers/__base/__init__.py b/api/graphon/model_runtime/model_providers/__base/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/model_providers/__base/ai_model.py b/api/graphon/model_runtime/model_providers/__base/ai_model.py deleted file mode 100644 index 1700ec9740..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/ai_model.py +++ /dev/null @@ -1,247 +0,0 @@ -import decimal - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from graphon.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - ModelType, - PriceConfig, - PriceInfo, - PriceType, -) -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.runtime import ModelRuntime - - -class AIModel: - """ - Runtime-facing base class for all model providers. - - This stays a regular Python class because instances hold live collaborators - such as the provider schema and runtime adapter rather than user input that - benefits from Pydantic validation. Subclasses must pin ``model_type`` via a - class attribute; the base class is not meant to be instantiated directly. - """ - - model_type: ModelType - provider_schema: ProviderEntity - model_runtime: ModelRuntime - started_at: float - - def __init__( - self, - provider_schema: ProviderEntity, - model_runtime: ModelRuntime, - *, - started_at: float = 0, - ) -> None: - if getattr(type(self), "model_type", None) is None: - raise TypeError("AIModel subclasses must define model_type as a class attribute") - - self.model_type = type(self).model_type - self.provider_schema = provider_schema - self.model_runtime = model_runtime - self.started_at = started_at - - @property - def provider(self) -> str: - return self.provider_schema.provider - - @property - def provider_display_name(self) -> str: - return self.provider_schema.label.en_US - - @property - def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: - """ - Map model invoke error to unified error. - - The key is the error type thrown to the caller, and the value contains - runtime-facing exception types that should be normalized to it. - """ - return { - InvokeConnectionError: [InvokeConnectionError], - InvokeServerUnavailableError: [InvokeServerUnavailableError], - InvokeRateLimitError: [InvokeRateLimitError], - InvokeAuthorizationError: [InvokeAuthorizationError], - InvokeBadRequestError: [InvokeBadRequestError], - ValueError: [ValueError], - } - - def _transform_invoke_error(self, error: Exception) -> Exception: - """ - Transform invoke error to unified error - - :param error: model invoke error - :return: unified error - """ - for invoke_error, model_errors in self._invoke_error_mapping.items(): - if isinstance(error, tuple(model_errors)): - if invoke_error == InvokeAuthorizationError: - return InvokeAuthorizationError( - description=( - f"[{self.provider_display_name}] Incorrect model credentials provided, " - "please check and try again." 
- ) - ) - elif issubclass(invoke_error, InvokeError): - return InvokeError( - description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}" - ) - else: - return error - - return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}") - - def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: - """ - Get price for given model and tokens - - :param model: model name - :param credentials: model credentials - :param price_type: price type - :param tokens: number of tokens - :return: price info - """ - # get model schema - model_schema = self.get_model_schema(model, credentials) - - # get price info from predefined model schema - price_config: PriceConfig | None = None - if model_schema and model_schema.pricing: - price_config = model_schema.pricing - - # get unit price - unit_price = None - if price_config: - if price_type == PriceType.INPUT: - unit_price = price_config.input - elif price_type == PriceType.OUTPUT and price_config.output is not None: - unit_price = price_config.output - - if unit_price is None: - return PriceInfo( - unit_price=decimal.Decimal("0.0"), - unit=decimal.Decimal("0.0"), - total_amount=decimal.Decimal("0.0"), - currency="USD", - ) - - # calculate total amount - if not price_config: - raise ValueError(f"Price config not found for model {model}") - total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) - - return PriceInfo( - unit_price=unit_price, - unit=price_config.unit, - total_amount=total_amount, - currency=price_config.currency, - ) - - def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None: - """ - Get model schema by model name and credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - return self.model_runtime.get_model_schema( - provider=self.provider, - model_type=self.model_type, - model=model, - credentials=credentials or {}, - ) - - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema from credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - - # get customizable model schema - schema = self.get_customizable_model_schema(model, credentials) - if not schema: - return None - - # fill in the template - new_parameter_rules = [] - for parameter_rule in schema.parameter_rules: - if parameter_rule.use_template: - try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) - default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and "max" in default_parameter_rule: - parameter_rule.max = default_parameter_rule["max"] - if not parameter_rule.min and "min" in default_parameter_rule: - parameter_rule.min = default_parameter_rule["min"] - if not parameter_rule.default and "default" in default_parameter_rule: - parameter_rule.default = default_parameter_rule["default"] - if not parameter_rule.precision and "precision" in default_parameter_rule: - parameter_rule.precision = default_parameter_rule["precision"] - if not parameter_rule.required and "required" in default_parameter_rule: - parameter_rule.required = default_parameter_rule["required"] - if not parameter_rule.help and "help" in default_parameter_rule: - 
parameter_rule.help = I18nObject( - en_US=default_parameter_rule["help"]["en_US"], - ) - if ( - parameter_rule.help - and not parameter_rule.help.en_US - and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"]) - ): - parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"] - if ( - parameter_rule.help - and not parameter_rule.help.zh_Hans - and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"]) - ): - parameter_rule.help.zh_Hans = default_parameter_rule["help"].get( - "zh_Hans", default_parameter_rule["help"]["en_US"] - ) - except ValueError: - pass - - new_parameter_rules.append(parameter_rule) - - schema.parameter_rules = new_parameter_rules - - return schema - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - return None - - def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName): - """ - Get default parameter rule for given name - - :param name: parameter name - :return: parameter rule - """ - default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) - - if not default_parameter_rule: - raise Exception(f"Invalid model parameter rule name {name}") - - return default_parameter_rule diff --git a/api/graphon/model_runtime/model_providers/__base/large_language_model.py b/api/graphon/model_runtime/model_providers/__base/large_language_model.py deleted file mode 100644 index 0f909646a1..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/large_language_model.py +++ /dev/null @@ -1,638 +0,0 @@ -import logging -import time -import uuid -from collections.abc import Callable, Generator, Iterator, Mapping, Sequence -from typing import Union - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.callbacks.logging_callback import LoggingCallback -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageContentUnionTypes, - PromptMessageTool, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ( - ModelType, - PriceType, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -def _gen_tool_call_id() -> str: - return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" - - -def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None: - if not callbacks: - return - - for callback in callbacks: - try: - invoke(callback) - except Exception as e: - if callback.raise_error: - raise - logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e) - - -def _get_or_create_tool_call( - existing_tools_calls: list[AssistantPromptMessage.ToolCall], - tool_call_id: str, -) -> AssistantPromptMessage.ToolCall: - """ - Get or create a tool call by ID. - - If `tool_call_id` is empty, returns the most recently created tool call. 
- """ - if not tool_call_id: - if not existing_tools_calls: - raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta") - return existing_tools_calls[-1] - - tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - existing_tools_calls.append(tool_call) - - return tool_call - - -def _merge_tool_call_delta( - tool_call: AssistantPromptMessage.ToolCall, - delta: AssistantPromptMessage.ToolCall, -) -> None: - if delta.id: - tool_call.id = delta.id - if delta.type: - tool_call.type = delta.type - if delta.function.name: - tool_call.function.name = delta.function.name - if delta.function.arguments: - tool_call.function.arguments += delta.function.arguments - - -def _build_llm_result_from_chunks( - model: str, - prompt_messages: Sequence[PromptMessage], - chunks: Iterator[LLMResultChunk], -) -> LLMResult: - """ - Build a single `LLMResult` by accumulating all returned chunks. - - Some models only support streaming output (e.g. Qwen3 open-source edition) - and the plugin side may still implement the response via a chunked stream, - so all chunks must be consumed and concatenated into a single ``LLMResult``. - - The ``usage`` is taken from the last chunk that carries it, which is the - typical convention for streaming responses (the final chunk contains the - aggregated token counts). - """ - content = "" - content_list: list[PromptMessageContentUnionTypes] = [] - usage = LLMUsage.empty_usage() - system_fingerprint: str | None = None - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - try: - for chunk in chunks: - if isinstance(chunk.delta.message.content, str): - content += chunk.delta.message.content - elif isinstance(chunk.delta.message.content, list): - content_list.extend(chunk.delta.message.content) - - if chunk.delta.message.tool_calls: - _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - - if chunk.delta.usage: - usage = chunk.delta.usage - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception: - logger.exception("Error while consuming non-stream plugin chunk iterator.") - raise - finally: - # Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections). 
- close = getattr(chunks, "close", None) - if callable(close): - close() - - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=content or content_list, - tool_calls=tools_calls, - ), - usage=usage, - system_fingerprint=system_fingerprint, - ) - - -def _invoke_llm_via_runtime( - *, - llm_model: "LargeLanguageModel", - provider: str, - model: str, - credentials: dict, - model_parameters: dict, - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, -) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - return llm_model.model_runtime.invoke_llm( - provider=provider, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=list(prompt_messages), - tools=tools, - stop=stop, - stream=stream, - ) - - -def _normalize_non_stream_runtime_result( - model: str, - prompt_messages: Sequence[PromptMessage], - result: Union[LLMResult, Iterator[LLMResultChunk]], -) -> LLMResult: - if isinstance(result, LLMResult): - return result - return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result) - - -def _increase_tool_call( - new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] -): - """ - Merge incremental tool call updates into existing tool calls. - - :param new_tool_calls: List of new tool call deltas to be merged. - :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. - """ - - for new_tool_call in new_tool_calls: - # generate ID for tool calls with function name but no ID to track them - if new_tool_call.function.name and not new_tool_call.id: - new_tool_call.id = _gen_tool_call_id() - - tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id) - _merge_tool_call_delta(tool_call, new_tool_call) - - -class LargeLanguageModel(AIModel): - """ - Model class for large language model. 
- """ - - model_type: ModelType = ModelType.LLM - - def invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - callbacks: list[Callback] | None = None, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param callbacks: callbacks - :return: full response or stream response chunk generator result - """ - # validate and filter model parameters - if model_parameters is None: - model_parameters = {} - - self.started_at = time.perf_counter() - - callbacks = callbacks or [] - - if logger.isEnabledFor(logging.DEBUG): - callbacks.append(LoggingCallback()) - - # trigger before invoke callbacks - self._trigger_before_invoke_callbacks( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - - result: Union[LLMResult, Generator[LLMResultChunk, None, None]] - - try: - result = _invoke_llm_via_runtime( - llm_model=self, - provider=self.provider, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=prompt_messages, - tools=tools, - stop=stop, - stream=stream, - ) - - if not stream: - result = _normalize_non_stream_runtime_result( - model=model, prompt_messages=prompt_messages, result=result - ) - except Exception as e: - self._trigger_invoke_error_callbacks( - model=model, - ex=e, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - - # TODO - raise self._transform_invoke_error(e) - - if stream and not isinstance(result, LLMResult): - return self._invoke_result_generator( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - elif isinstance(result, LLMResult): - self._trigger_after_invoke_callbacks( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - callbacks=callbacks, - ) - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. 
- result.prompt_messages = prompt_messages - return result - raise NotImplementedError("unsupported invoke result type", type(result)) - - def _invoke_result_generator( - self, - model: str, - result: Generator[LLMResultChunk, None, None], - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ) -> Generator[LLMResultChunk, None, None]: - """ - Invoke result generator - - :param result: result generator - :return: result generator - """ - callbacks = callbacks or [] - message_content: list[PromptMessageContentUnionTypes] = [] - usage = None - system_fingerprint = None - real_model = model - - def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None): - if not content: - return - if isinstance(content, list): - message_content.extend(content) - return - if isinstance(content, str): - message_content.append(TextPromptMessageContent(data=content)) - return - - try: - for chunk in result: - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. - chunk.prompt_messages = prompt_messages - yield chunk - - self._trigger_new_chunk_callbacks( - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - callbacks=callbacks, - ) - - _update_message_content(chunk.delta.message.content) - - real_model = chunk.model - if chunk.delta.usage: - usage = chunk.delta.usage - - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception as e: - raise self._transform_invoke_error(e) - - assistant_message = AssistantPromptMessage(content=message_content) - self._trigger_after_invoke_callbacks( - model=model, - result=LLMResult( - model=real_model, - prompt_messages=prompt_messages, - message=assistant_message, - usage=usage or LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint, - ), - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - callbacks=callbacks, - ) - - def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None, - ) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :return: - """ - return self.model_runtime.get_llm_num_tokens( - provider=self.provider, - model_type=self.model_type, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, - ) - - def calc_response_usage( - self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int - ) -> LLMUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param prompt_tokens: prompt tokens - :param completion_tokens: completion tokens - :return: usage - """ - # get prompt price info - prompt_price_info = self.get_price( - model=model, - 
credentials=credentials, - price_type=PriceType.INPUT, - tokens=prompt_tokens, - ) - - # get completion price info - completion_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens - ) - - # transform usage - usage = LLMUsage( - prompt_tokens=prompt_tokens, - prompt_unit_price=prompt_price_info.unit_price, - prompt_price_unit=prompt_price_info.unit, - prompt_price=prompt_price_info.total_amount, - completion_tokens=completion_tokens, - completion_unit_price=completion_price_info.unit_price, - completion_price_unit=completion_price_info.unit, - completion_price=completion_price_info.total_amount, - total_tokens=prompt_tokens + completion_tokens, - total_price=prompt_price_info.total_amount + completion_price_info.total_amount, - currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) - - return usage - - def _trigger_before_invoke_callbacks( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger before invoke callbacks - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_before_invoke", - invoke=lambda callback: callback.on_before_invoke( - llm_instance=self, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_new_chunk_callbacks( - self, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger new chunk callbacks - - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - """ - _run_callbacks( - callbacks, - event="on_new_chunk", - invoke=lambda callback: callback.on_new_chunk( - llm_instance=self, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_after_invoke_callbacks( - self, - model: str, - result: LLMResult, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = 
None, - ): - """ - Trigger after invoke callbacks - - :param model: model name - :param result: result - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_after_invoke", - invoke=lambda callback: callback.on_after_invoke( - llm_instance=self, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) - - def _trigger_invoke_error_callbacks( - self, - model: str, - ex: Exception, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - invocation_context: Mapping[str, object] | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger invoke error callbacks - - :param model: model name - :param ex: exception - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param invocation_context: opaque request metadata for the current invocation - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_invoke_error", - invoke=lambda callback: callback.on_invoke_error( - llm_instance=self, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - invocation_context=invocation_context, - ), - ) diff --git a/api/graphon/model_runtime/model_providers/__base/moderation_model.py b/api/graphon/model_runtime/model_providers/__base/moderation_model.py deleted file mode 100644 index 01f6842998..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/moderation_model.py +++ /dev/null @@ -1,33 +0,0 @@ -import time - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class ModerationModel(AIModel): - """ - Model class for moderation model. 
- """ - - model_type: ModelType = ModelType.MODERATION - - def invoke(self, model: str, credentials: dict, text: str) -> bool: - """ - Invoke moderation model - - :param model: model name - :param credentials: model credentials - :param text: text to moderate - :return: false if text is safe, true otherwise - """ - self.started_at = time.perf_counter() - - try: - return self.model_runtime.invoke_moderation( - provider=self.provider, - model=model, - credentials=credentials, - text=text, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/rerank_model.py b/api/graphon/model_runtime/model_providers/__base/rerank_model.py deleted file mode 100644 index 94b2b5a4fb..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/rerank_model.py +++ /dev/null @@ -1,76 +0,0 @@ -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class RerankModel(AIModel): - """ - Base Model class for rerank model. - """ - - model_type: ModelType = ModelType.RERANK - - def invoke( - self, - model: str, - credentials: dict, - query: str, - docs: list[str], - score_threshold: float | None = None, - top_n: int | None = None, - ) -> RerankResult: - """ - Invoke rerank model - - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :return: rerank result - """ - try: - return self.model_runtime.invoke_rerank( - provider=self.provider, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def invoke_multimodal_rerank( - self, - model: str, - credentials: dict, - query: MultimodalRerankInput, - docs: list[MultimodalRerankInput], - score_threshold: float | None = None, - top_n: int | None = None, - ) -> RerankResult: - """ - Invoke multimodal rerank model - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :return: rerank result - """ - try: - return self.model_runtime.invoke_multimodal_rerank( - provider=self.provider, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/speech2text_model.py b/api/graphon/model_runtime/model_providers/__base/speech2text_model.py deleted file mode 100644 index 4f5d648639..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/speech2text_model.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import IO - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class Speech2TextModel(AIModel): - """ - Model class for speech2text model. 
- """ - - model_type: ModelType = ModelType.SPEECH2TEXT - - def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: - """ - Invoke speech to text model - - :param model: model name - :param credentials: model credentials - :param file: audio file - :return: text for given audio file - """ - try: - return self.model_runtime.invoke_speech_to_text( - provider=self.provider, - model=model, - credentials=credentials, - file=file, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py b/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py deleted file mode 100644 index c8b4a0a6af..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py +++ /dev/null @@ -1,98 +0,0 @@ -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - - -class TextEmbeddingModel(AIModel): - """ - Model class for text embedding model. - """ - - model_type: ModelType = ModelType.TEXT_EMBEDDING - - def invoke( - self, - model: str, - credentials: dict, - texts: list[str] | None = None, - multimodel_documents: list[dict] | None = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> EmbeddingResult: - """ - Invoke text embedding model - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param files: files to embed - :param input_type: input type - :return: embeddings result - """ - try: - if texts: - return self.model_runtime.invoke_text_embedding( - provider=self.provider, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) - if multimodel_documents: - return self.model_runtime.invoke_multimodal_embedding( - provider=self.provider, - model=model, - credentials=credentials, - documents=multimodel_documents, - input_type=input_type, - ) - raise ValueError("No texts or files provided") - except Exception as e: - raise self._transform_invoke_error(e) - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - return self.model_runtime.get_text_embedding_num_tokens( - provider=self.provider, - model=model, - credentials=credentials, - texts=texts, - ) - - def _get_context_size(self, model: str, credentials: dict) -> int: - """ - Get context size for given embedding model - - :param model: model name - :param credentials: model credentials - :return: context size - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] - return content_size - - return 1000 - - def _get_max_chunks(self, model: str, credentials: dict) -> int: - """ - Get max chunks for given embedding model - - :param model: model name - :param credentials: model credentials - :return: max chunks - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - max_chunks: int = 
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] - return max_chunks - - return 1 diff --git a/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py deleted file mode 100644 index 3967acf07b..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging -from threading import Lock -from typing import Any - -logger = logging.getLogger(__name__) - -_tokenizer: Any | None = None -_lock = Lock() - - -class GPT2Tokenizer: - @staticmethod - def _get_num_tokens_by_gpt2(text: str) -> int: - """ - use gpt2 tokenizer to get num tokens - """ - _tokenizer = GPT2Tokenizer.get_encoder() - tokens = _tokenizer.encode(text) # type: ignore - return len(tokens) - - @staticmethod - def get_num_tokens(text: str) -> int: - # Because this process needs more cpu resource, we turn this back before we find a better way to handle it. - # - # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) - # result = future.result() - # return cast(int, result) - return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - - @staticmethod - def get_encoder(): - global _tokenizer, _lock - if _tokenizer is not None: - return _tokenizer - with _lock: - if _tokenizer is None: - # Try to use tiktoken to get the tokenizer because it is faster - # - try: - import tiktoken - - _tokenizer = tiktoken.get_encoding("gpt2") - except Exception: - from os.path import abspath, dirname, join - - from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer - - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), "gpt2") - _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - - return _tokenizer diff --git a/api/graphon/model_runtime/model_providers/__base/tts_model.py b/api/graphon/model_runtime/model_providers/__base/tts_model.py deleted file mode 100644 index 6846f3c403..0000000000 --- a/api/graphon/model_runtime/model_providers/__base/tts_model.py +++ /dev/null @@ -1,58 +0,0 @@ -import logging -from collections.abc import Iterable - -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class TTSModel(AIModel): - """ - Model class for TTS model. - """ - - model_type: ModelType = ModelType.TTS - - def invoke( - self, - model: str, - credentials: dict, - content_text: str, - voice: str, - ) -> Iterable[bytes]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param voice: model timbre - :param content_text: text content to be translated - :return: translated audio file - """ - try: - return self.model_runtime.invoke_tts( - provider=self.provider, - model=model, - credentials=credentials, - content_text=content_text, - voice=voice, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None): - """ - Retrieves the list of voices supported by a given text-to-speech (TTS) model. - - :param language: The language for which the voices are requested. - :param model: The name of the TTS model. - :param credentials: The credentials required to access the TTS model. - :return: A list of voices supported by the TTS model. 
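Since invoke returns Iterable[bytes], callers are expected to stream the audio rather than buffer it; a minimal consumption sketch (model, voice, and credentials are placeholders):

    tts = TTSModel(provider_schema=provider_schema, model_runtime=runtime)  # hypothetical construction
    with open("/tmp/greeting.mp3", "wb") as audio_file:
        for chunk in tts.invoke(
            model="tts-1",  # placeholder model name
            credentials={"api_key": "..."},
            content_text="Hello from the workflow engine.",
            voice="alloy",  # placeholder timbre id
        ):
            audio_file.write(chunk)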
- """ - return self.model_runtime.get_tts_model_voices( - provider=self.provider, - model=model, - credentials=credentials, - language=language, - ) diff --git a/api/graphon/model_runtime/model_providers/__init__.py b/api/graphon/model_runtime/model_providers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/model_providers/_position.yaml b/api/graphon/model_runtime/model_providers/_position.yaml deleted file mode 100644 index fb02de3a67..0000000000 --- a/api/graphon/model_runtime/model_providers/_position.yaml +++ /dev/null @@ -1,43 +0,0 @@ -- openai -- deepseek -- anthropic -- azure_openai -- google -- vertex_ai -- nvidia -- nvidia_nim -- cohere -- upstage -- bedrock -- togetherai -- openrouter -- ollama -- mistralai -- groq -- replicate -- huggingface_hub -- xinference -- triton_inference_server -- zhipuai -- baichuan -- spark -- minimax -- tongyi -- wenxin -- moonshot -- tencent -- jina -- chatglm -- yi -- openllm -- localai -- volcengine_maas -- openai_api_compatible -- hunyuan -- siliconflow -- perfxcloud -- zhinao -- fireworks -- mixedbread -- nomic -- voyage diff --git a/api/graphon/model_runtime/model_providers/model_provider_factory.py b/api/graphon/model_runtime/model_providers/model_provider_factory.py deleted file mode 100644 index 1ea30c7120..0000000000 --- a/api/graphon/model_runtime/model_providers/model_provider_factory.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence - -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel -from graphon.model_runtime.runtime import ModelRuntime -from graphon.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from graphon.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) - - -class ModelProviderFactory: - """Factory for provider schemas and model-type instances backed by a runtime adapter.""" - - def __init__(self, model_runtime: ModelRuntime): - if model_runtime is None: - raise ValueError("model_runtime is required.") - self.model_runtime = model_runtime - - def get_providers(self) -> Sequence[ProviderEntity]: - """ - Get all providers. - """ - return list(self.get_model_providers()) - - def get_model_providers(self) -> Sequence[ProviderEntity]: - """ - Get all model providers exposed by the runtime adapter. - """ - return self.model_runtime.fetch_model_providers() - - def get_provider_schema(self, provider: str) -> ProviderEntity: - """ - Get provider schema. - """ - return self.get_model_provider(provider=provider) - - def get_model_provider(self, provider: str) -> ProviderEntity: - """ - Get provider schema. 
- """ - provider_entity = self._resolve_provider(provider) - if provider_entity is None: - raise ValueError(f"Invalid provider: {provider}") - - return provider_entity - - def provider_credentials_validate(self, *, provider: str, credentials: dict): - """ - Validate provider credentials. - """ - provider_entity = self.get_model_provider(provider=provider) - - provider_credential_schema = provider_entity.provider_credential_schema - if not provider_credential_schema: - raise ValueError(f"Provider {provider} does not have provider_credential_schema") - - validator = ProviderCredentialSchemaValidator(provider_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - self.model_runtime.validate_provider_credentials( - provider=provider_entity.provider, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): - """ - Validate model credentials. - """ - provider_entity = self.get_model_provider(provider=provider) - - model_credential_schema = provider_entity.model_credential_schema - if not model_credential_schema: - raise ValueError(f"Provider {provider} does not have model_credential_schema") - - validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - self.model_runtime.validate_model_credentials( - provider=provider_entity.provider, - model_type=model_type, - model=model, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None - ) -> AIModelEntity | None: - """ - Get model schema. - """ - provider_entity = self.get_model_provider(provider) - return self.model_runtime.get_model_schema( - provider=provider_entity.provider, - model_type=model_type, - model=model, - credentials=credentials or {}, - ) - - def get_models( - self, - *, - provider: str | None = None, - model_type: ModelType | None = None, - provider_configs: list[ProviderConfig] | None = None, - ) -> list[SimpleProviderEntity]: - """ - Get all models for given model type. - """ - providers = [] - for provider_entity in self.get_model_providers(): - if provider and not self._matches_provider(provider_entity, provider): - continue - - if model_type and model_type not in provider_entity.supported_model_types: - continue - - simple_provider_schema = provider_entity.to_simple_provider() - if model_type is not None: - simple_provider_schema.models = [ - model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type - ] - providers.append(simple_provider_schema) - - return providers - - def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: - """ - Get model type instance by provider name and model type. 
- """ - provider_schema = self.get_model_provider(provider) - - if model_type == ModelType.LLM: - return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.RERANK: - return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.MODERATION: - return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - if model_type == ModelType.TTS: - return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime) - - raise ValueError(f"Unsupported model type: {model_type}") - - def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: - """ - Get provider icon. - """ - provider_entity = self.get_model_provider(provider) - return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang) - - def _resolve_provider(self, provider: str) -> ProviderEntity | None: - return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None) - - @staticmethod - def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool: - return provider in (provider_entity.provider, provider_entity.provider_name) diff --git a/api/graphon/model_runtime/runtime.py b/api/graphon/model_runtime/runtime.py deleted file mode 100644 index 79862bab8b..0000000000 --- a/api/graphon/model_runtime/runtime.py +++ /dev/null @@ -1,159 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Iterable, Sequence -from typing import IO, Any, Protocol, Union, runtime_checkable - -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult - - -@runtime_checkable -class ModelRuntime(Protocol): - """Port for provider discovery, schema lookup, and model execution. - - `provider` is the model runtime's canonical provider identifier. Adapters may - derive transport-specific details from it, but those details stay outside - this boundary. - """ - - def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... - - def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ... - - def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ... - - def validate_model_credentials( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - ) -> None: ... - - def get_model_schema( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - ) -> AIModelEntity | None: ... 
- - def invoke_llm( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ... - - def get_llm_num_tokens( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: Sequence[PromptMessageTool] | None, - ) -> int: ... - - def invoke_text_embedding( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - texts: list[str], - input_type: EmbeddingInputType, - ) -> EmbeddingResult: ... - - def invoke_multimodal_embedding( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - documents: list[dict[str, Any]], - input_type: EmbeddingInputType, - ) -> EmbeddingResult: ... - - def get_text_embedding_num_tokens( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - texts: list[str], - ) -> list[int]: ... - - def invoke_rerank( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - query: str, - docs: list[str], - score_threshold: float | None, - top_n: int | None, - ) -> RerankResult: ... - - def invoke_multimodal_rerank( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - query: MultimodalRerankInput, - docs: list[MultimodalRerankInput], - score_threshold: float | None, - top_n: int | None, - ) -> RerankResult: ... - - def invoke_tts( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - content_text: str, - voice: str, - ) -> Iterable[bytes]: ... - - def get_tts_model_voices( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - language: str | None, - ) -> Any: ... - - def invoke_speech_to_text( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - file: IO[bytes], - ) -> str: ... - - def invoke_moderation( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - text: str, - ) -> bool: ... 
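Because ModelRuntime is declared @runtime_checkable, adapters are verified structurally rather than by inheritance. A sketch of the shape a test double takes (names are hypothetical; only two methods are shown):

    class FakeModelRuntime:
        def fetch_model_providers(self):
            return []  # no providers registered

        def invoke_moderation(self, *, provider, model, credentials, text):
            return False  # treat every input as safe

        # ...the remaining Protocol methods would follow the same pattern.
        # isinstance(obj, ModelRuntime) only returns True once every method
        # named by the Protocol exists: @runtime_checkable checks member
        # presence, not signatures.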
diff --git a/api/graphon/model_runtime/schema_validators/__init__.py b/api/graphon/model_runtime/schema_validators/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/model_runtime/schema_validators/common_validator.py b/api/graphon/model_runtime/schema_validators/common_validator.py deleted file mode 100644 index 984507081b..0000000000 --- a/api/graphon/model_runtime/schema_validators/common_validator.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import Union, cast - -from graphon.model_runtime.entities.provider_entities import CredentialFormSchema, FormType - - -class CommonValidator: - def _validate_and_filter_credential_form_schemas( - self, credential_form_schemas: list[CredentialFormSchema], credentials: dict - ): - need_validate_credential_form_schema_map = {} - for credential_form_schema in credential_form_schemas: - if not credential_form_schema.show_on: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - continue - - all_show_on_match = True - for show_on_object in credential_form_schema.show_on: - if show_on_object.variable not in credentials: - all_show_on_match = False - break - - if credentials[show_on_object.variable] != show_on_object.value: - all_show_on_match = False - break - - if all_show_on_match: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - - # Iterate over the remaining credential_form_schemas, verify each credential_form_schema - validated_credentials = {} - for credential_form_schema in need_validate_credential_form_schema_map.values(): - # add the value of the credential_form_schema corresponding to it to validated_credentials - result = self._validate_credential_form_schema(credential_form_schema, credentials) - if result: - validated_credentials[credential_form_schema.variable] = result - - return validated_credentials - - def _validate_credential_form_schema( - self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Union[str, bool, None]: - """ - Validate credential form schema - - :param credential_form_schema: credential form schema - :param credentials: credentials - :return: validated credential form schema value - """ - # If the variable does not exist in credentials - value: Union[str, bool, None] = None - if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: - # If required is True, an exception is thrown - if credential_form_schema.required: - raise ValueError(f"Variable {credential_form_schema.variable} is required") - else: - # Get the value of default - if credential_form_schema.default: - # If it exists, add it to validated_credentials - return credential_form_schema.default - else: - # If default does not exist, skip - return None - - # Get the value corresponding to the variable from credentials - value = cast(str, credentials[credential_form_schema.variable]) - - # If max_length=0, no validation is performed - if credential_form_schema.max_length: - if len(value) > credential_form_schema.max_length: - raise ValueError( - f"Variable {credential_form_schema.variable} length should not be" - f" greater than {credential_form_schema.max_length}" - ) - - # check the type of value - if not isinstance(value, str): - raise ValueError(f"Variable {credential_form_schema.variable} should be string") - - if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: - # If the value is in options, no validation is performed - if 
credential_form_schema.options:
-                if value not in [option.value for option in credential_form_schema.options]:
-                    raise ValueError(f"Variable {credential_form_schema.variable} is not in options")
-
-        if credential_form_schema.type == FormType.SWITCH:
-            # If the value is not in ['true', 'false'], an exception is thrown
-            if value.lower() not in {"true", "false"}:
-                raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
-
-            value = value.lower() == "true"
-
-        return value
diff --git a/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py
deleted file mode 100644
index 9e4830c1b7..0000000000
--- a/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.provider_entities import ModelCredentialSchema
-from graphon.model_runtime.schema_validators.common_validator import CommonValidator
-
-
-class ModelCredentialSchemaValidator(CommonValidator):
-    def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema):
-        self.model_type = model_type
-        self.model_credential_schema = model_credential_schema
-
-    def validate_and_filter(self, credentials: dict):
-        """
-        Validate model credentials
-
-        :param credentials: model credentials
-        :return: filtered credentials
-        """
-
-        if self.model_credential_schema is None:
-            raise ValueError("Model credential schema is None")
-
-        # get the credential_form_schemas in model_credential_schema
-        credential_form_schemas = self.model_credential_schema.credential_form_schemas
-
-        credentials["__model_type"] = self.model_type.value
-
-        return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials)
diff --git a/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py
deleted file mode 100644
index 05fd3ce142..0000000000
--- a/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from graphon.model_runtime.entities.provider_entities import ProviderCredentialSchema
-from graphon.model_runtime.schema_validators.common_validator import CommonValidator
-
-
-class ProviderCredentialSchemaValidator(CommonValidator):
-    def __init__(self, provider_credential_schema: ProviderCredentialSchema):
-        self.provider_credential_schema = provider_credential_schema
-
-    def validate_and_filter(self, credentials: dict):
-        """
-        Validate provider credentials
-
-        :param credentials: provider credentials
-        :return: validated provider credentials
-        """
-        # get the credential_form_schemas in provider_credential_schema
-        credential_form_schemas = self.provider_credential_schema.credential_form_schemas
-
-        return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials)
diff --git a/api/graphon/model_runtime/utils/__init__.py b/api/graphon/model_runtime/utils/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/graphon/model_runtime/utils/encoders.py b/api/graphon/model_runtime/utils/encoders.py
deleted file mode 100644
index 13abf74767..0000000000
--- a/api/graphon/model_runtime/utils/encoders.py
+++ /dev/null
@@ -1,218 +0,0 @@
-import dataclasses
-import datetime
-from collections import defaultdict, deque
-from collections.abc import Callable, Sequence
-from decimal import Decimal
-from enum import Enum
-from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
-from pathlib import Path, PurePath
-from re import Pattern
-from types import GeneratorType
-from typing import Any, Literal, Union
-from uuid import UUID
-
-from pydantic import BaseModel
-from pydantic.networks import AnyUrl, NameEmail
-from pydantic.types import SecretBytes, SecretStr
-from pydantic_core import Url
-from pydantic_extra_types.color import Color
-
-
-def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
-    return model.model_dump(mode=mode, **kwargs)
-
-
-# Taken from Pydantic v1 as is
-def isoformat(o: Union[datetime.date, datetime.time]) -> str:
-    return o.isoformat()
-
-
-# Taken from Pydantic v1 as is
-# TODO: pv2 should this return strings instead?
-def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
-    """
-    Encodes a Decimal as an int if there is no exponent, otherwise as a float
-
-    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
-    where an integer (but not int typed) is used. Encoding this as a float
-    results in failed round-tripping between encode and parse.
-    Our Id type is a prime example of this.
-
-    >>> decimal_encoder(Decimal("1.0"))
-    1.0
-
-    >>> decimal_encoder(Decimal("1"))
-    1
-    """
-    if dec_value.as_tuple().exponent >= 0:  # type: ignore[operator]
-        return int(dec_value)
-    else:
-        return float(dec_value)
-
-
-ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
-    bytes: lambda o: o.decode(),
-    Color: str,
-    datetime.date: isoformat,
-    datetime.datetime: isoformat,
-    datetime.time: isoformat,
-    datetime.timedelta: lambda td: td.total_seconds(),
-    Decimal: decimal_encoder,
-    Enum: lambda o: o.value,
-    frozenset: list,
-    deque: list,
-    GeneratorType: list,
-    IPv4Address: str,
-    IPv4Interface: str,
-    IPv4Network: str,
-    IPv6Address: str,
-    IPv6Interface: str,
-    IPv6Network: str,
-    NameEmail: str,
-    Path: str,
-    Pattern: lambda o: o.pattern,
-    SecretBytes: str,
-    SecretStr: str,
-    set: list,
-    UUID: str,
-    Url: str,
-    AnyUrl: str,
-}
-
-
-def generate_encoders_by_class_tuples(
-    type_encoder_map: dict[Any, Callable[[Any], Any]],
-) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
-    encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple)
-    for type_, encoder in type_encoder_map.items():
-        encoders_by_class_tuples[encoder] += (type_,)
-    return encoders_by_class_tuples
-
-
-encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
-
-
-def jsonable_encoder(
-    obj: Any,
-    by_alias: bool = True,
-    exclude_unset: bool = False,
-    exclude_defaults: bool = False,
-    exclude_none: bool = False,
-    custom_encoder: dict[Any, Callable[[Any], Any]] | None = None,
-    excluded_key_prefixes: Sequence[str] = (),
-) -> Any:
-    custom_encoder = custom_encoder or {}
-    if custom_encoder:
-        if type(obj) in custom_encoder:
-            return custom_encoder[type(obj)](obj)
-        else:
-            for encoder_type, encoder_instance in custom_encoder.items():
-                if isinstance(obj, encoder_type):
-                    return encoder_instance(obj)
-    if isinstance(obj, BaseModel):
-        obj_dict = _model_dump(
-            obj,
-            mode="json",
-            include=None,
-            exclude=None,
-            by_alias=by_alias,
-            exclude_unset=exclude_unset,
-            exclude_none=exclude_none,
-            exclude_defaults=exclude_defaults,
-        )
-        if "__root__" in obj_dict:
-            obj_dict = obj_dict["__root__"]
-        return jsonable_encoder(
-            obj_dict,
-            exclude_none=exclude_none,
- exclude_defaults=exclude_defaults, - excluded_key_prefixes=excluded_key_prefixes, - ) - if dataclasses.is_dataclass(obj): - # Ensure obj is a dataclass instance, not a dataclass type - if not isinstance(obj, type): - obj_dict = dataclasses.asdict(obj) - return jsonable_encoder( - obj_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - if isinstance(obj, Enum): - return obj.value - if isinstance(obj, PurePath): - return str(obj) - if isinstance(obj, str | int | float | type(None)): - return obj - if isinstance(obj, Decimal): - return format(obj, "f") - if isinstance(obj, dict): - encoded_dict = {} - for key, value in obj.items(): - if isinstance(key, str) and any(key.startswith(prefix) for prefix in excluded_key_prefixes): - continue - if value is None and exclude_none: - continue - - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict - if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) - ) - return encoded_list - - if type(obj) in ENCODERS_BY_TYPE: - return ENCODERS_BY_TYPE[type(obj)](obj) - for encoder, classes_tuple in encoders_by_class_tuples.items(): - if isinstance(obj, classes_tuple): - return encoder(obj) - - try: - data = dict(obj) # type: ignore - except Exception as e: - errors: list[Exception] = [] - errors.append(e) - try: - data = vars(obj) # type: ignore - except Exception as e: - errors.append(e) - raise ValueError(str(errors)) from e - return jsonable_encoder( - data, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - excluded_key_prefixes=excluded_key_prefixes, - ) diff --git a/api/graphon/node_events/__init__.py b/api/graphon/node_events/__init__.py deleted file mode 100644 index a2bbf9f176..0000000000 --- a/api/graphon/node_events/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -from .agent import AgentLogEvent -from .base import NodeEventBase, NodeRunResult -from .iteration import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, -) -from .loop import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, -) -from .node import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - ModelInvokeCompletedEvent, - PauseRequestedEvent, - RunRetrieverResourceEvent, - RunRetryEvent, - StreamChunkEvent, - StreamCompletedEvent, - VariableUpdatedEvent, -) - -__all__ = [ - "AgentLogEvent", - "HumanInputFormFilledEvent", - "HumanInputFormTimeoutEvent", - "IterationFailedEvent", - "IterationNextEvent", - "IterationStartedEvent", - "IterationSucceededEvent", - "LoopFailedEvent", - "LoopNextEvent", - "LoopStartedEvent", 
- "LoopSucceededEvent", - "ModelInvokeCompletedEvent", - "NodeEventBase", - "NodeRunResult", - "PauseRequestedEvent", - "RunRetrieverResourceEvent", - "RunRetryEvent", - "StreamChunkEvent", - "StreamCompletedEvent", - "VariableUpdatedEvent", -] diff --git a/api/graphon/node_events/agent.py b/api/graphon/node_events/agent.py deleted file mode 100644 index bf295ec774..0000000000 --- a/api/graphon/node_events/agent.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class AgentLogEvent(NodeEventBase): - message_id: str = Field(..., description="id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata") - node_id: str = Field(..., description="node id") diff --git a/api/graphon/node_events/base.py b/api/graphon/node_events/base.py deleted file mode 100644 index dcd1672428..0000000000 --- a/api/graphon/node_events/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage - - -class NodeEventBase(BaseModel): - """Base class for all node events""" - - pass - - -def _default_metadata(): - v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - return v - - -class NodeRunResult(BaseModel): - """ - Node Run Result. 
- """ - - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.PENDING - - inputs: Mapping[str, Any] = Field(default_factory=dict) - process_data: Mapping[str, Any] = Field(default_factory=dict) - outputs: Mapping[str, Any] = Field(default_factory=dict) - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata) - llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) - - edge_source_handle: str = "source" # source handle id of node with multiple branches - - error: str = "" - error_type: str = "" - - # single step node run retry - retry_index: int = 0 diff --git a/api/graphon/node_events/iteration.py b/api/graphon/node_events/iteration.py deleted file mode 100644 index 744ddea628..0000000000 --- a/api/graphon/node_events/iteration.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class IterationStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class IterationNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class IterationSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class IterationFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/node_events/loop.py b/api/graphon/node_events/loop.py deleted file mode 100644 index 3ae230f9f6..0000000000 --- a/api/graphon/node_events/loop.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class LoopStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class LoopNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class LoopSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class LoopFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/graphon/node_events/node.py 
b/api/graphon/node_events/node.py deleted file mode 100644 index 17f1494cf2..0000000000 --- a/api/graphon/node_events/node.py +++ /dev/null @@ -1,72 +0,0 @@ -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from pydantic import Field - -from graphon.entities.pause_reason import PauseReason -from graphon.file import File -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import NodeRunResult -from graphon.variables.variables import Variable - -from .base import NodeEventBase - - -class RunRetrieverResourceEvent(NodeEventBase): - retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - context_files: list[File] | None = Field(default=None, description="context files") - - -class ModelInvokeCompletedEvent(NodeEventBase): - text: str - usage: LLMUsage - finish_reason: str | None = None - reasoning_content: str | None = None - structured_output: dict | None = None - - -class RunRetryEvent(NodeEventBase): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="Retry attempt number") - start_at: datetime = Field(..., description="Retry start time") - - -class StreamChunkEvent(NodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = Field(default=False, description="indicates if this is the last chunk") - - -class StreamCompletedEvent(NodeEventBase): - node_run_result: NodeRunResult = Field(..., description="run result") - - -class VariableUpdatedEvent(NodeEventBase): - """Notify the engine that a single variable should be applied to the shared pool.""" - - variable: Variable = Field(..., description="Updated variable payload to apply.") - - -class PauseRequestedEvent(NodeEventBase): - reason: PauseReason = Field(..., description="pause reason") - - -class HumanInputFormFilledEvent(NodeEventBase): - """Event emitted when a human input form is submitted.""" - - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputFormTimeoutEvent(NodeEventBase): - """Event emitted when a human input form times out.""" - - node_title: str - expiration_time: datetime diff --git a/api/graphon/nodes/__init__.py b/api/graphon/nodes/__init__.py deleted file mode 100644 index 2d376d104d..0000000000 --- a/api/graphon/nodes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from graphon.enums import BuiltinNodeTypes - -__all__ = ["BuiltinNodeTypes"] diff --git a/api/graphon/nodes/answer/__init__.py b/api/graphon/nodes/answer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/nodes/answer/answer_node.py b/api/graphon/nodes/answer/answer_node.py deleted file mode 100644 index c5261a7939..0000000000 --- a/api/graphon/nodes/answer/answer_node.py +++ /dev/null @@ -1,70 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.answer.entities import AnswerNodeData -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.variables import 
ArrayFileSegment, FileSegment, Segment - - -class AnswerNode(Node[AnswerNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.RESPONSE - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer) - files = self._extract_files_from_segments(segments.value) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)}, - ) - - def _extract_files_from_segments(self, segments: Sequence[Segment]): - """Extract all files from segments containing FileSegment or ArrayFileSegment instances. - - FileSegment contains a single file, while ArrayFileSegment contains multiple files. - This method flattens all files into a single list. - """ - files = [] - for segment in segments: - if isinstance(segment, FileSegment): - # Single file - wrap in list for consistency - files.append(segment.value) - elif isinstance(segment, ArrayFileSegment): - # Multiple files - extend the list - files.extend(segment.value) - return files - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AnswerNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this Answer node - """ - return Template.from_answer_template(self.node_data.answer) diff --git a/api/graphon/nodes/answer/entities.py b/api/graphon/nodes/answer/entities.py deleted file mode 100644 index c49f1f3895..0000000000 --- a/api/graphon/nodes/answer/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class AnswerNodeData(BaseNodeData): - """ - Answer Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ANSWER - answer: str = Field(..., description="answer template string") - - -class GenerateRouteChunk(BaseModel): - """ - Generate Route Chunk. - """ - - class ChunkType(StrEnum): - VAR = auto() - TEXT = auto() - - type: ChunkType = Field(..., description="generate route chunk type") - - -class VarGenerateRouteChunk(GenerateRouteChunk): - """ - Var Generate Route Chunk. - """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR - """generate route chunk type""" - value_selector: Sequence[str] = Field(..., description="value selector") - - -class TextGenerateRouteChunk(GenerateRouteChunk): - """ - Text Generate Route Chunk. 
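Back in AnswerNode above, _extract_files_from_segments flattens both segment shapes into a single list; in effect (with hypothetical File values file_a through file_c):

    segments = [FileSegment(value=file_a), ArrayFileSegment(value=[file_b, file_c])]
    assert node._extract_files_from_segments(segments) == [file_a, file_b, file_c]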
- """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT - """generate route chunk type""" - text: str = Field(..., description="text") - - -class AnswerNodeDoubleLink(BaseModel): - node_id: str = Field(..., description="node id") - source_node_ids: list[str] = Field(..., description="source node ids") - target_node_ids: list[str] = Field(..., description="target node ids") - - -class AnswerStreamGenerateRoute(BaseModel): - """ - AnswerStreamGenerateRoute entity - """ - - answer_dependencies: dict[str, list[str]] = Field( - ..., description="answer dependencies (answer node id -> dependent answer node ids)" - ) - answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( - ..., description="answer generate route (answer node id -> generate route chunks)" - ) diff --git a/api/graphon/nodes/base/__init__.py b/api/graphon/nodes/base/__init__.py deleted file mode 100644 index 036e25895d..0000000000 --- a/api/graphon/nodes/base/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState -from .usage_tracking_mixin import LLMUsageTrackingMixin - -__all__ = [ - "BaseIterationNodeData", - "BaseIterationState", - "BaseLoopNodeData", - "BaseLoopState", - "LLMUsageTrackingMixin", -] diff --git a/api/graphon/nodes/base/entities.py b/api/graphon/nodes/base/entities.py deleted file mode 100644 index 94b88c097d..0000000000 --- a/api/graphon/nodes/base/entities.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from pydantic import BaseModel, field_validator - -from graphon.entities.base_node_data import BaseNodeData - - -class VariableSelector(BaseModel): - """ - Variable Selector. - """ - - variable: str - value_selector: Sequence[str] - - -class OutputVariableType(StrEnum): - STRING = "string" - NUMBER = "number" - INTEGER = "integer" - SECRET = "secret" - BOOLEAN = "boolean" - OBJECT = "object" - FILE = "file" - ARRAY = "array" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_BOOLEAN = "array[boolean]" - ARRAY_FILE = "array[file]" - ANY = "any" - ARRAY_ANY = "array[any]" - - -class OutputVariableEntity(BaseModel): - """ - Output Variable Entity. - """ - - variable: str - value_type: OutputVariableType = OutputVariableType.ANY - value_selector: Sequence[str] - - @field_validator("value_type", mode="before") - @classmethod - def normalize_value_type(cls, v: Any) -> Any: - """ - Normalize value_type to handle case-insensitive array types. - Converts 'Array[...]' to 'array[...]' for backward compatibility. 
- """ - if isinstance(v, str) and v.startswith("Array["): - return v.lower() - return v - - -class BaseIterationNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseIterationState(BaseModel): - iteration_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData - - -class BaseLoopNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseLoopState(BaseModel): - loop_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData diff --git a/api/graphon/nodes/base/node.py b/api/graphon/nodes/base/node.py deleted file mode 100644 index 613ff4f037..0000000000 --- a/api/graphon/nodes/base/node.py +++ /dev/null @@ -1,787 +0,0 @@ -from __future__ import annotations - -import logging -import operator -from abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime -from functools import singledispatchmethod -from types import MappingProxyType -from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin -from uuid import uuid4 - -from graphon.entities import GraphInitParams -from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - ErrorStrategy, - NodeExecutionType, - NodeState, - NodeType, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) -from graphon.node_events import ( - AgentLogEvent, - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - PauseRequestedEvent, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, - VariableUpdatedEvent, -) -from graphon.runtime import GraphRuntimeState - -NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) -_MISSING_RUN_CONTEXT_VALUE = object() - -logger = logging.getLogger(__name__) - - -class Node(Generic[NodeDataT]): - """BaseNode serves as the foundational class for all node implementations. - - Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` - attribute to track files generated by the LLM). However, these states are not persisted - when the workflow is suspended or resumed. If a node needs its state to be preserved - across workflow suspension and resumption, it should include the relevant state data - in its output. - """ - - node_type: ClassVar[NodeType] - execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE - _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData - - def __init_subclass__(cls, **kwargs: Any) -> None: - """ - Automatically extract and validate the node data type from the generic parameter. - - When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method: - 1. 
Inspects `__orig_bases__` to find the `Node[T]` parameterization - 2. Extracts `T` (e.g., `MyNodeData`) from the generic argument - 3. Validates that `T` is a proper `BaseNodeData` subclass - 4. Stores it in `_node_data_type` for automatic hydration in `__init__` - - This eliminates the need for subclasses to manually implement boilerplate - accessor methods like `_get_title()`, `_get_error_strategy()`, etc. - - How it works: - :: - - class CodeNode(Node[CodeNodeData]): - │ │ - │ └─────────────────────────────────┐ - │ │ - ▼ ▼ - ┌─────────────────────────────┐ ┌─────────────────────────────────┐ - │ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │ - │ Node[CodeNodeData], │ │ title: str │ - │ ) │ │ desc: str | None │ - └──────────────┬──────────────┘ │ ... │ - │ └─────────────────────────────────┘ - ▼ ▲ - ┌─────────────────────────────┐ │ - │ get_origin(base) -> Node │ │ - │ get_args(base) -> ( │ │ - │ CodeNodeData, │ ──────────────────────┘ - │ ) │ - └──────────────┬──────────────┘ - │ - ▼ - ┌─────────────────────────────┐ - │ Validate: │ - │ - Is it a type? │ - │ - Is it a BaseNodeData │ - │ subclass? │ - └──────────────┬──────────────┘ - │ - ▼ - ┌─────────────────────────────┐ - │ cls._node_data_type = │ - │ CodeNodeData │ - └─────────────────────────────┘ - - Later, in __init__: - :: - - config["data"] ──► _node_data_type.model_validate(..., from_attributes=True) - │ - ▼ - CodeNodeData instance - (stored in self._node_data) - - Example: - class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted - node_type = BuiltinNodeTypes.CODE - # No need to implement _get_title, _get_error_strategy, etc. - """ - super().__init_subclass__(**kwargs) - - if cls is Node: - return - - node_data_type = cls._extract_node_data_type_from_generic() - - if node_data_type is None: - raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype") - - cls._node_data_type = node_data_type - - # Skip base class itself - if cls is Node: - return - # Only treat nodes from the base graphon package as production - # registrations. Higher-layer packages may still register subclasses, - # but graphon itself should not know their module identities. - # This prevents test helper subclasses from polluting the global registry and - # accidentally overriding real node types (e.g., a test Answer node). - module_name = getattr(cls, "__module__", "") - # Only register concrete subclasses that define node_type and version() - node_type = cls.node_type - version = cls.version() - bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith("graphon.nodes."): - # Production node definitions take precedence and may override - bucket[version] = cls # type: ignore[index] - else: - # External/test subclasses may register but must not override production - bucket.setdefault(version, cls) # type: ignore[index] - # Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic - version_keys = [v for v in bucket if v != "latest"] - numeric_pairs: list[tuple[str, int]] = [] - for v in version_keys: - numeric_pairs.append((v, int(v))) - if numeric_pairs: - latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0] - else: - latest_key = max(version_keys) if version_keys else version - bucket["latest"] = bucket[latest_key] - Node._registry_version += 1 - - @classmethod - def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: - """ - Extract the node data type from the generic parameter `Node[T]`. 
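The same __orig_bases__ inspection works for any generic base; a self-contained sketch of the idea, independent of this codebase:

    from typing import Generic, TypeVar, get_args, get_origin

    T = TypeVar("T")

    class Repo(Generic[T]):
        _item_type: type | None = None

        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)
            for base in getattr(cls, "__orig_bases__", ()):
                if get_origin(base) is Repo:
                    cls._item_type = get_args(base)[0]  # the concrete T

    class UserRepo(Repo[dict]):
        pass

    assert UserRepo._item_type is dict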
- - Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`. - - Returns: - The extracted BaseNodeData subtype, or None if not found. - - Raises: - TypeError: If the generic argument is invalid (not exactly one argument, - or not a BaseNodeData subtype). - """ - # __orig_bases__ contains the original generic bases before type erasure. - # For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`. - for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined] - origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]` - if origin is Node: - args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]` - if len(args) != 1: - raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument") - - candidate = args[0] - if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData): - raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype") - - return candidate - - return None - - # Global registry populated via __init_subclass__ - _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} - _registry_version: ClassVar[int] = 0 - - @classmethod - def get_registry_version(cls) -> int: - return cls._registry_version - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - self._graph_init_params = graph_init_params - self._run_context = MappingProxyType(dict(graph_init_params.run_context)) - self.id = id - self.workflow_id = graph_init_params.workflow_id - self.graph_config = graph_init_params.graph_config - self.workflow_call_depth = graph_init_params.call_depth - self.graph_runtime_state = graph_runtime_state - self.state: NodeState = NodeState.UNKNOWN # node execution state - - node_id = config["id"] - - self._node_id = node_id - self._node_execution_id: str = "" - self._start_at = datetime.now(UTC).replace(tzinfo=None) - - self._node_data = self.validate_node_data(config["data"]) - - self.post_init() - - @classmethod - def validate_node_data(cls, node_data: BaseNodeData | Mapping[str, Any]) -> NodeDataT: - """Validate shared graph node payloads against the subclass-declared NodeData model. - - Re-validate from a dumped payload instead of `from_attributes=True` so compatibility - extras stored on `BaseNodeData` survive the handoff to the concrete node data model. - Human Input delivery methods are one such extra field until graphon owns that schema. 
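A standalone illustration of why the dump-then-revalidate round trip preserves compatibility extras, assuming the base model allows them (the classes here are hypothetical, not the real BaseNodeData):

    from pydantic import BaseModel, ConfigDict

    class Base(BaseModel):
        model_config = ConfigDict(extra="allow")
        title: str

    class Concrete(Base):
        delivery_method: str = "email"

    base = Base.model_validate({"title": "form", "delivery_method": "sms"})
    concrete = Concrete.model_validate(base.model_dump(mode="python"))
    assert concrete.delivery_method == "sms"  # the extra key survived the handoff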
- """ - if isinstance(node_data, BaseNodeData): - payload = node_data.model_dump(mode="python") - else: - payload = dict(node_data) - return cast(NodeDataT, cls._node_data_type.model_validate(payload)) - - def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: - """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" - self._node_data = self.validate_node_data(cast(BaseNodeData, data)) - - def post_init(self) -> None: - """Optional hook for subclasses requiring extra initialization.""" - return - - @property - def graph_init_params(self) -> GraphInitParams: - return self._graph_init_params - - @property - def run_context(self) -> Mapping[str, Any]: - return self._run_context - - def get_run_context_value(self, key: str, default: Any = None) -> Any: - return self._run_context.get(key, default) - - def require_run_context_value(self, key: str) -> Any: - value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE) - if value is _MISSING_RUN_CONTEXT_VALUE: - raise ValueError(f"run_context missing required key: {key}") - return value - - @property - def execution_id(self) -> str: - return self._node_execution_id - - def ensure_execution_id(self) -> str: - if self._node_execution_id: - return self._node_execution_id - - resumed_execution_id = self._restore_execution_id_from_runtime_state() - if resumed_execution_id: - self._node_execution_id = resumed_execution_id - return self._node_execution_id - - self._node_execution_id = str(uuid4()) - return self._node_execution_id - - def _restore_execution_id_from_runtime_state(self) -> str | None: - graph_execution = self.graph_runtime_state.graph_execution - try: - node_executions = graph_execution.node_executions - except AttributeError: - return None - if not isinstance(node_executions, dict): - return None - node_execution = node_executions.get(self._node_id) - if node_execution is None: - return None - execution_id = node_execution.execution_id - if not execution_id: - return None - return str(execution_id) - - @abstractmethod - def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: - """ - Run node - :return: - """ - raise NotImplementedError - - def populate_start_event(self, event: NodeRunStartedEvent) -> None: - """Allow subclasses to enrich the started event without cross-node imports in the base class.""" - _ = event - - def run(self) -> Generator[GraphNodeEventBase, None, None]: - execution_id = self.ensure_execution_id() - self._start_at = datetime.now(UTC).replace(tzinfo=None) - - # Create and push start event with required fields - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.title, - in_iteration_id=None, - start_at=self._start_at, - ) - try: - self.populate_start_event(start_event) - except Exception: - logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True) - yield start_event - - try: - result = self._run() - - # Handle NodeRunResult - if isinstance(result, NodeRunResult): - yield self._convert_node_run_result_to_graph_node_event(result) - return - - # Handle event stream - for event in result: - # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase - if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] - yield self._dispatch(event) - elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] - event.id = 
self.execution_id - yield event - else: - yield event - except Exception as e: - logger.exception("Node %s failed to run", self._node_id) - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - error_type="WorkflowNodeError", - ) - finished_at = datetime.now(UTC).replace(tzinfo=None) - yield NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=str(e), - ) - - @classmethod - def extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - config: NodeConfigDict, - ) -> Mapping[str, Sequence[str]]: - """Extracts referenced variable selectors from node configuration. - - The `config` parameter represents the configuration for a specific node type and corresponds - to the `data` field in the node definition object. - - The returned mapping has the following structure: - - {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} - - For loop and iteration nodes, the mapping may look like this: - - { - "1748332301644.input_selector": ["1748332363630", "result"], - "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"], - } - - where `1748332301644` is the ID of the loop / iteration node, - and `1748332325079` is the ID of the node inside the loop or iteration node. - - Here, the key consists of two parts: the current node ID (provided as the `node_id` - parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, - enclosed in `#` symbols. These two parts are separated by a dot (`.`). - - The value is a list of strings representing the variable selector, where the first element is the node ID - of the referenced variable, and the second element is the variable name within that node. - - The meaning of the above mapping is: - - The node with ID `1747829548239` references the variable `result` from the node with - ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a - reference to the `result` output variable of node `1747829667553`. - - :param graph_config: graph config - :param config: node config - :return: - """ - node_id = config["id"] - node_data = cls.validate_node_data(config["data"]) - data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - node_id=node_id, - node_data=node_data, - ) - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: NodeDataT, - ) -> Mapping[str, Sequence[str]]: - return {} - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this node blocks the output of specific variables. - - This method is used to determine if a node must complete execution before - the specified variables can be used in streaming output. 
- - :param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str')) - :return: True if this node blocks output of any of the specified variables, False otherwise - """ - return False - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return {} - - @classmethod - @abstractmethod - def version(cls) -> str: - """Returns the version of the current node type.""" - # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so - # registry lookups can resolve numeric versions and `latest`. - raise NotImplementedError("subclasses of Node must implement the `version` method.") - - @classmethod - def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: - """Return a read-only view of the currently registered node classes. - - This accessor intentionally performs no imports. The embedding layer that - owns bootstrap (for example `core.workflow.node_factory`) must import any - extension node packages before calling it so their subclasses register via - `__init_subclass__`. - """ - return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()} - - @property - def retry(self) -> bool: - return False - - def _get_error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._node_data.retry_config - - def _get_title(self) -> str: - """Get the node title.""" - return self._node_data.title - - def _get_description(self) -> str | None: - """Get the node description.""" - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._node_data.default_value_dict - - # Public interface properties that delegate to the protected getters - @property - def error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._get_error_strategy() - - @property - def retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._get_retry_config() - - @property - def title(self) -> str: - """Get the node title.""" - return self._get_title() - - @property - def description(self) -> str | None: - """Get the node description.""" - return self._get_description() - - @property - def default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._get_default_value_dict() - - @property - def node_data(self) -> NodeDataT: - """Typed access to this node's configuration data.""" - return self._node_data - - def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: - finished_at = datetime.now(UTC).replace(tzinfo=None) - match result.status: - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=result.error, - ) - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - ) - case _: - raise Exception(f"result status {result.status} not supported") - - 
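For context on the block that follows: `_dispatch` is built on `functools.singledispatchmethod`, which selects an implementation by the runtime type of the first argument after `self` and falls back to the undecorated body when nothing is registered. A minimal, self-contained sketch of the pattern (the event class here is hypothetical, not part of this diff):

from dataclasses import dataclass
from functools import singledispatchmethod


@dataclass
class ChunkEvent:  # hypothetical stand-in for an event subclass
    chunk: str


class Dispatcher:
    @singledispatchmethod
    def dispatch(self, event: object) -> str:
        # Fallback for unregistered event types, mirroring the
        # NotImplementedError raised by the base _dispatch below.
        raise NotImplementedError(f"unsupported event type {type(event)}")

    @dispatch.register
    def _(self, event: ChunkEvent) -> str:
        # Chosen automatically whenever dispatch() receives a ChunkEvent.
        return f"chunk:{event.chunk}"


assert Dispatcher().dispatch(ChunkEvent("hi")) == "chunk:hi"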
@singledispatchmethod - def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase: - raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - - @_dispatch.register - def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: - return NodeRunStreamChunkEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - selector=event.selector, - chunk=event.chunk, - is_final=event.is_final, - ) - - @_dispatch.register - def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: - finished_at = datetime.now(UTC).replace(tzinfo=None) - match event.node_run_result.status: - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - ) - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - error=event.node_run_result.error, - ) - case _: - raise NotImplementedError( - f"Node {self._node_id} does not support status {event.node_run_result.status}" - ) - - @_dispatch.register - def _(self, event: VariableUpdatedEvent) -> NodeRunVariableUpdatedEvent: - return NodeRunVariableUpdatedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - variable=event.variable, - ) - - @_dispatch.register - def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: - return NodeRunPauseRequestedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), - reason=event.reason, - ) - - @_dispatch.register - def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: - return NodeRunAgentLogEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - message_id=event.message_id, - label=event.label, - node_execution_id=event.node_execution_id, - parent_id=event.parent_id, - error=event.error, - status=event.status, - data=event.data, - metadata=event.metadata, - ) - - @_dispatch.register - def _(self, event: HumanInputFormFilledEvent): - return NodeRunHumanInputFormFilledEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - - @_dispatch.register - def _(self, event: HumanInputFormTimeoutEvent): - return NodeRunHumanInputFormTimeoutEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - - @_dispatch.register - def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: - return NodeRunLoopStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: - return NodeRunLoopNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - 
pre_loop_output=event.pre_loop_output, - ) - - @_dispatch.register - def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: - return NodeRunLoopSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: - return NodeRunLoopFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: - return NodeRunIterationStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: - return NodeRunIterationNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - pre_iteration_output=event.pre_iteration_output, - ) - - @_dispatch.register - def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: - return NodeRunIterationSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: - return NodeRunIterationFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - return NodeRunRetrieverResourceEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - retriever_resources=event.retriever_resources, - context=event.context, - node_version=self.version(), - ) diff --git a/api/graphon/nodes/base/template.py b/api/graphon/nodes/base/template.py deleted file mode 100644 index 311de4a6ea..0000000000 --- a/api/graphon/nodes/base/template.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Template structures for Response nodes (Answer and End). - -This module provides a unified template structure for both Answer and End nodes, -similar to SegmentGroup but focused on template representation without values. 
-""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, Union - -from graphon.nodes.base.variable_template_parser import VariableTemplateParser - - -@dataclass(frozen=True) -class TemplateSegment(ABC): - """Base class for template segments.""" - - @abstractmethod - def __str__(self) -> str: - """String representation of the segment.""" - pass - - -@dataclass(frozen=True) -class TextSegment(TemplateSegment): - """A text segment in a template.""" - - text: str - - def __str__(self) -> str: - return self.text - - -@dataclass(frozen=True) -class VariableSegment(TemplateSegment): - """A variable reference segment in a template.""" - - selector: Sequence[str] - variable_name: str | None = None # Optional variable name for End nodes - - def __str__(self) -> str: - return "{{#" + ".".join(self.selector) + "#}}" - - -# Type alias for segments -TemplateSegmentUnion = Union[TextSegment, VariableSegment] - - -@dataclass(frozen=True) -class Template: - """Unified template structure for Response nodes. - - Similar to SegmentGroup, but represents the template structure - without variable values - only marking variable selectors. - """ - - segments: list[TemplateSegmentUnion] - - @classmethod - def from_answer_template(cls, template_str: str) -> Template: - """Create a Template from an Answer node template string. - - Example: - "Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])] - - Args: - template_str: The answer template string - - Returns: - Template instance - """ - parser = VariableTemplateParser(template_str) - segments: list[TemplateSegmentUnion] = [] - - # Extract variable selectors to find all variables - variable_selectors = parser.extract_variable_selectors() - var_map = {var.variable: var.value_selector for var in variable_selectors} - - # Parse template to get ordered segments - # We need to split the template by variable placeholders while preserving order - import re - - # Create a regex pattern that matches variable placeholders - pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}" - - # Split template while keeping the delimiters (variable placeholders) - parts = re.split(pattern, template_str) - - for i, part in enumerate(parts): - if not part: - continue - - # Check if this part is a variable reference (odd indices after split) - if i % 2 == 1: # Odd indices are variable keys - # Remove the # symbols from the variable key - var_key = part - if var_key in var_map: - segments.append(VariableSegment(selector=list(var_map[var_key]))) - else: - # This shouldn't happen with valid templates - segments.append(TextSegment(text="{{" + part + "}}")) - else: - # Even indices are text segments - segments.append(TextSegment(text=part)) - - return cls(segments=segments) - - @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: - """Create a Template from an End node outputs configuration. - - End nodes are treated as templates of concatenated variables with newlines. 
@classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: - """Create a Template from an End node outputs configuration. - - End nodes are treated as templates of concatenated variables with newlines. - - Example: - [{"variable": "text", "value_selector": ["node1", "text"]}, - {"variable": "result", "value_selector": ["node2", "result"]}] - -> - [VariableSegment(["node1", "text"]), - TextSegment("\n"), - VariableSegment(["node2", "result"])] - - Args: - outputs_config: List of output configurations with variable and value_selector - - Returns: - Template instance - """ - segments: list[TemplateSegmentUnion] = [] - - for i, output in enumerate(outputs_config): - if i > 0: - # Add newline separator between variables - segments.append(TextSegment(text="\n")) - - value_selector = output.get("value_selector", []) - variable_name = output.get("variable", "") - if value_selector: - segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name)) - - if len(segments) > 0 and isinstance(segments[-1], TextSegment): - segments = segments[:-1] - - return cls(segments=segments) - - def __str__(self) -> str: - """String representation of the template.""" - return "".join(str(segment) for segment in self.segments) diff --git a/api/graphon/nodes/base/usage_tracking_mixin.py b/api/graphon/nodes/base/usage_tracking_mixin.py deleted file mode 100644 index 955bfe6726..0000000000 --- a/api/graphon/nodes/base/usage_tracking_mixin.py +++ /dev/null @@ -1,28 +0,0 @@ -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState - - -class LLMUsageTrackingMixin: - """Provides shared helpers for merging and recording LLM usage within workflow nodes.""" - - graph_runtime_state: GraphRuntimeState - - @staticmethod - def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage: - """Return a combined usage snapshot, preserving zero-value inputs.""" - if new_usage is None or new_usage.total_tokens <= 0: - return current - if current.total_tokens == 0: - return new_usage - return current.plus(new_usage) - - def _accumulate_usage(self, usage: LLMUsage) -> None: - """Push usage into the graph runtime accumulator for downstream reporting.""" - if usage.total_tokens <= 0: - return - - current_usage = self.graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self.graph_runtime_state.llm_usage = usage.model_copy() - else: - self.graph_runtime_state.llm_usage = current_usage.plus(usage) diff --git a/api/graphon/nodes/base/variable_template_parser.py b/api/graphon/nodes/base/variable_template_parser.py deleted file mode 100644 index de5e619e8c..0000000000 --- a/api/graphon/nodes/base/variable_template_parser.py +++ /dev/null @@ -1,130 +0,0 @@ -import re -from collections.abc import Mapping, Sequence -from typing import Any - -from .entities import VariableSelector - -REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") - -SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") - - -def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]: - parts = SELECTOR_PATTERN.split(template) - selectors = [] - for part in filter(lambda x: x, parts): - if "." in part and part[0] == "#" and part[-1] == "#": - selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split("."))) - return selectors - - 
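A short usage sketch of `extract_selectors_from_template`, assuming the definitions above are in scope; the template and node IDs are hypothetical:

# Each `{{#node_id.field#}}` placeholder becomes a VariableSelector whose
# value_selector is the hash-stripped key split on dots.
selectors = extract_selectors_from_template("Hi {{#start.name#}}, run {{#sys.workflow_id#}}")
# selectors[0]: variable="#start.name#",      value_selector=["start", "name"]
# selectors[1]: variable="#sys.workflow_id#", value_selector=["sys", "workflow_id"]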
-class VariableTemplateParser: - """ - !NOTE: Consider using the new `segments` module instead of this class. - - A class for parsing and manipulating template variables in a string. - - Rules: - - 1. Template variables must be enclosed in `{{}}`. - 2. The template variable Key can only be: #node_id.var1.var2#. - 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2. - - Example usage: - - template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}." - parser = VariableTemplateParser(template) - - # Extract template variable keys - variable_keys = parser.extract() - print(variable_keys) - # Output: ['#node_id.query.name#', '#node_id.query.age#'] - - # Extract variable selectors - variable_selectors = parser.extract_variable_selectors() - print(variable_selectors) - # Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']), - # VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])] - - # Format the template string - inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25} - formatted_string = parser.format(inputs) - print(formatted_string) - # Output: "Hello, John! Your age is 25." - """ - - def __init__(self, template: str): - self.template = template - self.variable_keys = self.extract() - - def extract(self): - """ - Extracts all the template variable keys from the template string. - - Returns: - A list of template variable keys. - """ - # Regular expression to match the template rules - matches = re.findall(REGEX, self.template) - - first_group_matches = [match[0] for match in matches] - - return list(set(first_group_matches)) - - def extract_variable_selectors(self) -> list[VariableSelector]: - """ - Extracts the variable selectors from the template variable keys. - - Returns: - A list of VariableSelector objects representing the variable selectors. - """ - variable_selectors = [] - for variable_key in self.variable_keys: - remove_hash = variable_key.replace("#", "") - split_result = remove_hash.split(".") - if len(split_result) < 2: - continue - - variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result)) - - return variable_selectors - - def format(self, inputs: Mapping[str, Any]) -> str: - """ - Formats the template string by replacing the template variables with their corresponding values. - - Args: - inputs: A dictionary containing the values for the template variables. - - Returns: - The formatted string with template variables replaced by their values. - """ - - def replacer(match): - key = match.group(1) - value = inputs.get(key, match.group(0)) # return original matched string if key not found - - if value is None: - value = "" - # convert the value to string - if isinstance(value, list | dict | bool | int | float): - value = str(value) - - # neutralize any template variables carried in the substituted value - return VariableTemplateParser.remove_template_variables(value) - - prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r"<\|.*?\|>", "", prompt) - - @classmethod - def remove_template_variables(cls, text: str): - """ - Neutralizes template variables in the given text by collapsing `{{#...#}}` into `{#...#}`, so substituted values cannot trigger further substitution. - - Args: - text: The text in which to neutralize template variables. - - Returns: - The text with template variables neutralized. 
- """ - return re.sub(REGEX, r"{\1}", text) diff --git a/api/graphon/nodes/code/__init__.py b/api/graphon/nodes/code/__init__.py deleted file mode 100644 index 8c6dcc7fcc..0000000000 --- a/api/graphon/nodes/code/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .code_node import CodeNode - -__all__ = ["CodeNode"] diff --git a/api/graphon/nodes/code/code_node.py b/api/graphon/nodes/code/code_node.py deleted file mode 100644 index c2eea0bec1..0000000000 --- a/api/graphon/nodes/code/code_node.py +++ /dev/null @@ -1,493 +0,0 @@ -from collections.abc import Mapping, Sequence -from decimal import Decimal -from textwrap import dedent -from typing import TYPE_CHECKING, Any, Protocol, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.code.entities import CodeLanguage, CodeNodeData -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.variables.segments import ArrayFileSegment -from graphon.variables.types import SegmentType - -from .exc import ( - CodeNodeError, - DepthLimitError, - OutputValidationError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class WorkflowCodeExecutor(Protocol): - def execute( - self, - *, - language: CodeLanguage, - code: str, - inputs: Mapping[str, Any], - ) -> Mapping[str, Any]: ... - - def is_execution_error(self, error: Exception) -> bool: ... - - -def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]: - return { - "type": "code", - "config": { - "variables": [ - {"variable": "arg1", "value_selector": []}, - {"variable": "arg2", "value_selector": []}, - ], - "code_language": language, - "code": code, - "outputs": {"result": {"type": "string", "children": None}}, - }, - } - - -_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { - CodeLanguage.PYTHON3: dedent( - """ - def main(arg1: str, arg2: str): - return { - "result": arg1 + arg2, - } - """ - ), - CodeLanguage.JAVASCRIPT: dedent( - """ - function main({arg1, arg2}) { - return { - result: arg1 + arg2 - } - } - """ - ), -} - - -class CodeNode(Node[CodeNodeData]): - node_type = BuiltinNodeTypes.CODE - _limits: CodeNodeLimits - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - code_executor: WorkflowCodeExecutor, - code_limits: CodeNodeLimits, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._code_executor: WorkflowCodeExecutor = code_executor - self._limits = code_limits - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. 
- :return: - """ - code_language = CodeLanguage.PYTHON3 - if filters: - code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - - default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language) - if default_code is None: - raise CodeNodeError(f"Unsupported code language: {code_language}") - return _build_default_config(language=code_language, code=default_code) - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get code language - code_language = self.node_data.code_language - code = self.node_data.code - - # Get variables - variables = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if isinstance(variable, ArrayFileSegment): - variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None - else: - variables[variable_name] = variable.to_object() if variable else None - # Run code - try: - result = self._code_executor.execute( - language=code_language, - code=code, - inputs=variables, - ) - - # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) - except CodeNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - except Exception as e: - if not self._code_executor.is_execution_error(e): - raise - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - - def _check_string(self, value: str | None, variable: str) -> str | None: - """ - Check string - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if len(value) > self._limits.max_string_length: - raise OutputValidationError( - f"The length of output variable `{variable}` must be" - f" less than {self._limits.max_string_length} characters" - ) - - return value.replace("\x00", "") - - def _check_boolean(self, value: bool | None, variable: str) -> bool | None: - if value is None: - return None - - return value - - def _check_number(self, value: int | float | None, variable: str) -> int | float | None: - """ - Check number - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if value > self._limits.max_number or value < self._limits.min_number: - raise OutputValidationError( - f"Output variable `{variable}` is out of range," - f" it must be between {self._limits.min_number} and {self._limits.max_number}." - ) - - if isinstance(value, float): - decimal_value = Decimal(str(value)).normalize() - precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] - # raise error if precision is too high - if precision > self._limits.max_precision: - raise OutputValidationError( - f"Output variable `{variable}` has too high precision," - f" it must be less than {self._limits.max_precision} digits." - ) - - return value - - def _transform_result( - self, - result: Mapping[str, Any], - output_schema: dict[str, CodeNodeData.Output] | None, - prefix: str = "", - depth: int = 1, - ): - # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. 
- # Note that `_transform_result` may produce lists containing `None` values, - # which don't conform to the type requirements of `Array*Segment` classes. - if depth > self._limits.max_depth: - raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.") - - transformed_result: dict[str, Any] = {} - if output_schema is None: - # validate output through instance type checks - for output_name, output_value in result.items(): - if isinstance(output_value, dict): - self._transform_result( - result=output_value, - output_schema=None, - prefix=f"{prefix}.{output_name}" if prefix else output_name, - depth=depth + 1, - ) - elif isinstance(output_value, bool): - self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name) - elif isinstance(output_value, int | float): - self._check_number( - value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name - ) - elif isinstance(output_value, str): - self._check_string( - value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name - ) - elif isinstance(output_value, list): - first_element = output_value[0] if len(output_value) > 0 else None - if first_element is not None: - if isinstance(first_element, int | float) and all( - value is None or isinstance(value, int | float) for value in output_value - ): - for i, value in enumerate(output_value): - self._check_number( - value=value, - variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - ) - elif isinstance(first_element, str) and all( - value is None or isinstance(value, str) for value in output_value - ): - for i, value in enumerate(output_value): - self._check_string( - value=value, - variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - ) - elif ( - isinstance(first_element, dict) - and all(value is None or isinstance(value, dict) for value in output_value) - or isinstance(first_element, list) - and all(value is None or isinstance(value, list) for value in output_value) - ): - for i, value in enumerate(output_value): - if value is not None: - self._transform_result( - result=value, - output_schema=None, - prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - depth=depth + 1, - ) - else: - raise OutputValidationError( - f"Output {prefix}.{output_name} is not a valid array." - f" Make sure all elements are of the same type." - ) - elif output_value is None: - pass - else: - raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.") - - return result - - parameters_validated = {} - for output_name, output_config in output_schema.items(): - dot = "." if prefix else "" - if output_name not in result: - raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.") - - if output_config.type == SegmentType.OBJECT: - # check if output is object - if not isinstance(result.get(output_name), dict): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an object," - f" got {type(result.get(output_name))} instead." 
- ) - else: - transformed_result[output_name] = self._transform_result( - result=result[output_name], - output_schema=output_config.children, - prefix=f"{prefix}.{output_name}", - depth=depth + 1, - ) - elif output_config.type == SegmentType.NUMBER: - # check if number available - value = result.get(output_name) - if value is not None and not isinstance(value, (int, float)): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not a number," - f" got {type(result.get(output_name))} instead." - ) - checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}") - # If the output is a boolean and the output schema specifies a NUMBER type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. - transformed_result[output_name] = self._convert_boolean_to_int(checked) - - elif output_config.type == SegmentType.STRING: - # check if string available - value = result.get(output_name) - if value is not None and not isinstance(value, str): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead" - ) - transformed_result[output_name] = self._check_string( - value=value, - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.BOOLEAN: - transformed_result[output_name] = self._check_boolean( - value=result[output_name], - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.ARRAY_NUMBER: - # check if array of number available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead." - ) - else: - if len(value) > self._limits.max_number_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_number_array_length} elements." - ) - - for i, inner_value in enumerate(value): - if not isinstance(inner_value, (int, float)): - raise OutputValidationError( - f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be" - f" a number." - ) - _ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = [ - # If the element is a boolean and the output schema specifies a `array[number]` type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. - self._convert_boolean_to_int(v) - for v in value - ] - elif output_config.type == SegmentType.ARRAY_STRING: - # check if array of string available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_string_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_string_array_length} elements." 
- ) - - transformed_result[output_name] = [ - self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_OBJECT: - # check if array of object available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_object_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_object_array_length} elements." - ) - - for i, value in enumerate(result[output_name]): - if not isinstance(value, dict): - if value is None: - pass - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not an object," - f" got {type(value)} instead at index {i}." - ) - - transformed_result[output_name] = [ - None - if value is None - else self._transform_result( - result=value, - output_schema=output_config.children, - prefix=f"{prefix}{dot}{output_name}[{i}]", - depth=depth + 1, - ) - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_BOOLEAN: - # check if array of boolean available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - for i, inner_value in enumerate(value): - if inner_value is not None and not isinstance(inner_value, bool): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not a boolean," - f" got {type(inner_value)} instead." - ) - _ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = value - - else: - raise OutputValidationError(f"Output type {output_config.type} is not supported.") - - parameters_validated[output_name] = True - - # check if all output parameters are validated - if len(parameters_validated) != len(result): - raise CodeNodeError("Not all output parameters are validated.") - - return transformed_result - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: CodeNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @staticmethod - def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None: - """Converts booleans to integers when the output schema specifies a NUMBER type. - - This ensures compatibility with existing workflows that may use - `True` and `False` as values for NUMBER type outputs. - """ - if value is None: - return None - if isinstance(value, bool): - return int(value) - return value 
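The reason `_convert_boolean_to_int` exists at all: `bool` is a subclass of `int` in Python, so boolean outputs sail straight through the NUMBER-type `isinstance` checks above. A standalone illustration:

# bool passes isinstance checks against int, so the NUMBER branches must
# normalize booleans explicitly rather than rely on type filtering.
assert isinstance(True, int)
assert int(True) == 1 and int(False) == 0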
- """ - if value is None: - return None - if isinstance(value, bool): - return int(value) - return value diff --git a/api/graphon/nodes/code/entities.py b/api/graphon/nodes/code/entities.py deleted file mode 100644 index dc89d64495..0000000000 --- a/api/graphon/nodes/code/entities.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Literal - -from pydantic import AfterValidator, BaseModel - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import VariableSelector -from graphon.variables.types import SegmentType - - -class CodeLanguage(StrEnum): - PYTHON3 = "python3" - JINJA2 = "jinja2" - JAVASCRIPT = "javascript" - - -_ALLOWED_OUTPUT_FROM_CODE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _validate_type(segment_type: SegmentType) -> SegmentType: - if segment_type not in _ALLOWED_OUTPUT_FROM_CODE: - raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}") - return segment_type - - -class CodeNodeData(BaseNodeData): - """ - Code Node Data. - """ - - type: NodeType = BuiltinNodeTypes.CODE - - class Output(BaseModel): - type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: dict[str, "CodeNodeData.Output"] | None = None - - class Dependency(BaseModel): - name: str - version: str - - variables: list[VariableSelector] - code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] - code: str - outputs: dict[str, Output] - dependencies: list[Dependency] | None = None diff --git a/api/graphon/nodes/code/exc.py b/api/graphon/nodes/code/exc.py deleted file mode 100644 index d6334fd554..0000000000 --- a/api/graphon/nodes/code/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class CodeNodeError(ValueError): - """Base class for code node errors.""" - - pass - - -class OutputValidationError(CodeNodeError): - """Raised when there is an output validation error.""" - - pass - - -class DepthLimitError(CodeNodeError): - """Raised when the depth limit is reached.""" - - pass diff --git a/api/graphon/nodes/code/limits.py b/api/graphon/nodes/code/limits.py deleted file mode 100644 index a6b9e9e68e..0000000000 --- a/api/graphon/nodes/code/limits.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class CodeNodeLimits: - max_string_length: int - max_number: int | float - min_number: int | float - max_precision: int - max_depth: int - max_number_array_length: int - max_string_array_length: int - max_object_array_length: int diff --git a/api/graphon/nodes/document_extractor/__init__.py b/api/graphon/nodes/document_extractor/__init__.py deleted file mode 100644 index 9922e3949d..0000000000 --- a/api/graphon/nodes/document_extractor/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .node import DocumentExtractorNode - -__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"] diff --git a/api/graphon/nodes/document_extractor/entities.py b/api/graphon/nodes/document_extractor/entities.py deleted file mode 100644 index 026a0cd224..0000000000 --- a/api/graphon/nodes/document_extractor/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence -from dataclasses 
import dataclass - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class DocumentExtractorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - variable_selector: Sequence[str] - - -@dataclass(frozen=True) -class UnstructuredApiConfig: - api_url: str | None = None - api_key: str = "" diff --git a/api/graphon/nodes/document_extractor/exc.py b/api/graphon/nodes/document_extractor/exc.py deleted file mode 100644 index 5caf00ebc5..0000000000 --- a/api/graphon/nodes/document_extractor/exc.py +++ /dev/null @@ -1,14 +0,0 @@ -class DocumentExtractorError(ValueError): - """Base exception for errors related to the DocumentExtractorNode.""" - - -class FileDownloadError(DocumentExtractorError): - """Exception raised when there's an error downloading a file.""" - - -class UnsupportedFileTypeError(DocumentExtractorError): - """Exception raised when trying to extract text from an unsupported file type.""" - - -class TextExtractionError(DocumentExtractorError): - """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/graphon/nodes/document_extractor/node.py b/api/graphon/nodes/document_extractor/node.py deleted file mode 100644 index be46481e7d..0000000000 --- a/api/graphon/nodes/document_extractor/node.py +++ /dev/null @@ -1,782 +0,0 @@ -import csv -import io -import json -import logging -import os -import tempfile -import zipfile -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -import charset_normalizer -import docx -import pandas as pd -import pypandoc -import pypdfium2 -import webvtt -import yaml -from docx.document import Document -from docx.oxml.table import CT_Tbl -from docx.oxml.text.paragraph import CT_P -from docx.table import Table -from docx.text.paragraph import Paragraph - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, file_manager -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.protocols import HttpClientProtocol -from graphon.variables import ArrayFileSegment -from graphon.variables.segments import ArrayStringSegment, FileSegment - -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class DocumentExtractorNode(Node[DocumentExtractorNodeData]): - """ - Extracts text content from various file types. - Supports plain text, PDF, and DOC/DOCX files. 
- """ - - node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - unstructured_api_config: UnstructuredApiConfig | None = None, - http_client: HttpClientProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() - self._http_client = http_client - - def _run(self): - variable_selector = self.node_data.variable_selector - variable = self.graph_runtime_state.variable_pool.get(variable_selector) - - if variable is None: - error_message = f"File variable not found for selector: {variable_selector}" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): - error_message = f"Variable {variable_selector} is not an ArrayFileSegment" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - - value = variable.value - inputs = {"variable_selector": variable_selector} - if isinstance(value, list): - value = list(filter(lambda x: x, value)) - process_data = {"documents": value if isinstance(value, list) else [value]} - - if not value: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=[])}, - ) - - try: - if isinstance(value, list): - extracted_text_list = [ - _extract_text_from_file( - self._http_client, file, unstructured_api_config=self._unstructured_api_config - ) - for file in value - ] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=extracted_text_list)}, - ) - elif isinstance(value, File): - extracted_text = _extract_text_from_file( - self._http_client, value, unstructured_api_config=self._unstructured_api_config - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": extracted_text}, - ) - else: - raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") - except DocumentExtractorError as e: - logger.warning(e, exc_info=True) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: DocumentExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return {node_id + ".files": node_data.variable_selector} - - -def _extract_text_by_mime_type( - *, - file_content: bytes, - mime_type: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its MIME type.""" - match mime_type: - case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": - return _extract_text_from_plain_text(file_content) - case "application/pdf": - return _extract_text_from_pdf(file_content) - case "application/msword": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case 
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": - return _extract_text_from_docx(file_content) - case "text/csv": - return _extract_text_from_csv(file_content) - case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": - return _extract_text_from_excel(file_content) - case "application/vnd.ms-powerpoint": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case "application/epub+zip": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case "message/rfc822": - return _extract_text_from_eml(file_content) - case "application/vnd.ms-outlook": - return _extract_text_from_msg(file_content) - case "application/json": - return _extract_text_from_json(file_content) - case "application/x-yaml" | "text/yaml": - return _extract_text_from_yaml(file_content) - case "text/vtt": - return _extract_text_from_vtt(file_content) - case "text/properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") - - -def _extract_text_by_file_extension( - *, - file_content: bytes, - file_extension: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its file extension.""" - match file_extension: - case ( - ".txt" - | ".markdown" - | ".md" - | ".mdx" - | ".html" - | ".htm" - | ".xml" - | ".c" - | ".h" - | ".cpp" - | ".hpp" - | ".cc" - | ".cxx" - | ".c++" - | ".py" - | ".js" - | ".ts" - | ".jsx" - | ".tsx" - | ".java" - | ".php" - | ".rb" - | ".go" - | ".rs" - | ".swift" - | ".kt" - | ".scala" - | ".sh" - | ".bash" - | ".bat" - | ".ps1" - | ".sql" - | ".r" - | ".m" - | ".pl" - | ".lua" - | ".vim" - | ".asm" - | ".s" - | ".css" - | ".scss" - | ".less" - | ".sass" - | ".ini" - | ".cfg" - | ".conf" - | ".toml" - | ".env" - | ".log" - | ".vtt" - ): - return _extract_text_from_plain_text(file_content) - case ".json": - return _extract_text_from_json(file_content) - case ".yaml" | ".yml": - return _extract_text_from_yaml(file_content) - case ".pdf": - return _extract_text_from_pdf(file_content) - case ".doc": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case ".docx": - return _extract_text_from_docx(file_content) - case ".csv": - return _extract_text_from_csv(file_content) - case ".xls" | ".xlsx": - return _extract_text_from_excel(file_content) - case ".ppt": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case ".pptx": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case ".epub": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case ".eml": - return _extract_text_from_eml(file_content) - case ".msg": - return _extract_text_from_msg(file_content) - case ".properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") - - -def _extract_text_from_plain_text(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best() - if result: - encoding 
= result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - return file_content.decode(encoding, errors="ignore") - except (UnicodeDecodeError, LookupError) as e: - # If decoding fails, try with utf-8 as last resort - try: - return file_content.decode("utf-8", errors="ignore") - except UnicodeDecodeError: - raise TextExtractionError(f"Failed to decode plain text file: {e}") from e - - -def _extract_text_from_json(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - json_data = json.loads(file_content.decode(encoding, errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e: - # If decoding fails, try with utf-8 as last resort - try: - json_data = json.loads(file_content.decode("utf-8", errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, json.JSONDecodeError): - raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e - - -def _extract_text_from_yaml(file_content: bytes) -> str: - """Extract the content from yaml file""" - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: - # If decoding fails, try with utf-8 as last resort - try: - yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, yaml.YAMLError): - raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e - - -def _extract_text_from_pdf(file_content: bytes) -> str: - try: - pdf_file = io.BytesIO(file_content) - pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) - text = "" - for page in pdf_document: - text_page = page.get_textpage() - text += text_page.get_text_range() - text_page.close() - page.close() - return text - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e - - -def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - """ - Extract text from a DOC file. 
- """ - from unstructured.partition.api import partition_via_api - - if not unstructured_api_config.api_url: - raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.") - api_key = unstructured_api_config.api_key or "" - - try: - with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e - - -def parser_docx_part(block, doc: Document, content_items, i): - if isinstance(block, CT_P): - content_items.append((i, "paragraph", Paragraph(block, doc))) - elif isinstance(block, CT_Tbl): - content_items.append((i, "table", Table(block, doc))) - - -def _normalize_docx_zip(file_content: bytes) -> bytes: - """ - Some DOCX files (e.g. exported by Evernote on Windows) are malformed: - ZIP entry names use backslash (\\) as path separator instead of the forward - slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry - "word\\document.xml" is never found when python-docx looks for - "word/document.xml", which triggers a KeyError about a missing relationship. - - This function rewrites the ZIP in-memory, normalizing all entry names to - use forward slashes without touching any actual document content. - """ - try: - with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin: - out_buf = io.BytesIO() - with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout: - for item in zin.infolist(): - data = zin.read(item.filename) - # Normalize backslash path separators to forward slash - item.filename = item.filename.replace("\\", "/") - zout.writestr(item, data) - return out_buf.getvalue() - except zipfile.BadZipFile: - # Not a valid zip — return as-is and let python-docx report the real error - return file_content - - -def _extract_text_from_docx(file_content: bytes) -> str: - """ - Extract text from a DOCX file. - For now support only paragraph and table add more if needed - """ - try: - doc_file = io.BytesIO(file_content) - try: - doc = docx.Document(doc_file) - except Exception as e: - logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e) - # Some DOCX files exported by tools like Evernote on Windows use - # backslash path separators in ZIP entries and/or single-quoted XML - # attributes, both of which break python-docx on Linux. Normalize and retry. 
- file_content = _normalize_docx_zip(file_content) - doc = docx.Document(io.BytesIO(file_content)) - text = [] - - # Keep track of paragraph and table positions - content_items: list[tuple[int, str, Table | Paragraph]] = [] - - it = iter(doc.element.body) - part = next(it, None) - i = 0 - while part is not None: - parser_docx_part(part, doc, content_items, i) - i = i + 1 - part = next(it, None) - - # Process sorted content - for _, item_type, item in content_items: - if item_type == "paragraph": - if isinstance(item, Table): - continue - text.append(item.text) - elif item_type == "table": - # Process tables - if not isinstance(item, Table): - continue - try: - # Check if any cell in the table has text - has_content = False - for row in item.rows: - if any(cell.text.strip() for cell in row.cells): - has_content = True - break - - if has_content: - cell_texts = [cell.text.replace("\n", "
") for cell in item.rows[0].cells] - markdown_table = f"| {' | '.join(cell_texts)} |\n" - markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n" - - for row in item.rows[1:]: - # Replace newlines with
in each cell
- row_cells = [cell.text.replace("\n", "<br/>
") for cell in row.cells] - markdown_table += "| " + " | ".join(row_cells) + " |\n" - - text.append(markdown_table) - except Exception as e: - logger.warning("Failed to extract table from DOC: %s", e) - continue - - return "\n".join(text) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e - - -def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes: - """Download the content of a file based on its transfer method.""" - try: - if file.transfer_method == FileTransferMethod.REMOTE_URL: - if file.remote_url is None: - raise FileDownloadError("Missing URL for remote file") - response = http_client.get(file.remote_url) - response.raise_for_status() - return response.content - else: - return file_manager.download(file) - except Exception as e: - raise FileDownloadError(f"Error downloading file: {str(e)}") from e - - -def _extract_text_from_file( - http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig -) -> str: - file_content = _download_file_content(http_client, file) - if file.extension: - extracted_text = _extract_text_by_file_extension( - file_content=file_content, - file_extension=file.extension, - unstructured_api_config=unstructured_api_config, - ) - elif file.mime_type: - extracted_text = _extract_text_by_mime_type( - file_content=file_content, - mime_type=file.mime_type, - unstructured_api_config=unstructured_api_config, - ) - else: - raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") - return extracted_text - - -def _extract_text_from_csv(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - try: - csv_file = io.StringIO(file_content.decode(encoding, errors="ignore")) - except (UnicodeDecodeError, LookupError): - # If decoding fails, try with utf-8 as last resort - csv_file = io.StringIO(file_content.decode("utf-8", errors="ignore")) - - csv_reader = csv.reader(csv_file) - rows = list(csv_reader) - - if not rows: - return "" - - # Combine multi-line text in the header row - header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]] - - # Create Markdown table - markdown_table = "| " + " | ".join(header_row) + " |\n" - markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n" - - # Process each data row and combine multi-line text in each cell - for row in rows[1:]: - processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row] - markdown_table += "| " + " | ".join(processed_row) + " |\n" - - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e - - -def _extract_text_from_excel(file_content: bytes) -> str: - """Extract text from an Excel file using pandas.""" - - def _construct_markdown_table(df: pd.DataFrame) -> str: - """Manually construct a Markdown table from a DataFrame.""" - # Construct the header row - header_row = "| " + " | ".join(df.columns) + " |" - - # Construct the separator row - separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |" - - # Construct the data rows - data_rows = [] - for _, row in df.iterrows(): - data_row = "| " + " | ".join(map(str, row)) + " |" - data_rows.append(data_row) - - # 
Combine all rows into a single string - markdown_table = "\n".join([header_row, separator_row] + data_rows) - return markdown_table - - try: - excel_file = pd.ExcelFile(io.BytesIO(file_content)) - markdown_table = "" - for sheet_name in excel_file.sheet_names: - try: - df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how="all", inplace=True) - - # Combine multi-line text in each cell into a single line - df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) - - # Combine multi-line text in column names into a single line - df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) - - # Manually construct the Markdown table - markdown_table += _construct_markdown_table(df) + "\n\n" - except Exception: - continue - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e - - -def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.ppt import partition_ppt - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_ppt(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.pptx import partition_pptx - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_pptx(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.epub import partition_epub - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - pypandoc.download_pandoc() - with io.BytesIO(file_content) as 
file: - elements = partition_epub(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e - - -def _extract_text_from_eml(file_content: bytes) -> str: - from unstructured.partition.email import partition_email - - try: - with io.BytesIO(file_content) as file: - elements = partition_email(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e - - -def _extract_text_from_msg(file_content: bytes) -> str: - from unstructured.partition.msg import partition_msg - - try: - with io.BytesIO(file_content) as file: - elements = partition_msg(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e - - -def _extract_text_from_vtt(vtt_bytes: bytes) -> str: - text = _extract_text_from_plain_text(vtt_bytes) - - # remove bom - text = text.lstrip("\ufeff") - - raw_results = [] - for caption in webvtt.from_string(text): - raw_results.append((caption.voice, caption.text)) - - # Merge consecutive utterances by the same speaker - merged_results = [] - if raw_results: - current_speaker, current_text = raw_results[0] - - for i in range(1, len(raw_results)): - spk, txt = raw_results[i] - if spk is None: - merged_results.append((None, current_text)) - continue - - if spk == current_speaker: - # If it is the same speaker, merge the utterances (joined by space) - current_text += " " + txt - else: - # If the speaker changes, register the utterance so far and move on - merged_results.append((current_speaker, current_text)) - current_speaker, current_text = spk, txt - - # Add the last element - merged_results.append((current_speaker, current_text)) - else: - merged_results = raw_results - - # Return the result in the specified format: Speaker "text" style - formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results] - return "\n".join(formatted) - - -def _extract_text_from_properties(file_content: bytes) -> str: - try: - text = _extract_text_from_plain_text(file_content) - lines = text.splitlines() - result = [] - for line in lines: - line = line.strip() - # Preserve comments and empty lines - if not line or line.startswith("#") or line.startswith("!"): - result.append(line) - continue - - if "=" in line: - key, value = line.split("=", 1) - elif ":" in line: - key, value = line.split(":", 1) - else: - key, value = line, "" - - result.append(f"{key.strip()}: {value.strip()}") - - return "\n".join(result) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e diff --git a/api/graphon/nodes/end/__init__.py b/api/graphon/nodes/end/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/nodes/end/end_node.py b/api/graphon/nodes/end/end_node.py deleted file mode 100644 index 11b9e58644..0000000000 --- a/api/graphon/nodes/end/end_node.py +++ /dev/null @@ -1,47 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template -from graphon.nodes.end.entities import EndNodeData - - -class EndNode(Node[EndNodeData]): - node_type = BuiltinNodeTypes.END - execution_type = NodeExecutionType.RESPONSE - - 
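- # Note: NodeExecutionType.RESPONSE marks this node as response-producing;
- # per the _run docstring below, streaming (when enabled) completes first and
- # _run then collects the final output values.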
@classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - collect all outputs at once. - - This method runs after streaming is complete (if streaming was enabled). - It collects all output variables and returns them. - """ - output_variables = self.node_data.outputs - - outputs = {} - for variable_selector in output_variables: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - value = variable.to_object() if variable is not None else None - outputs[variable_selector.variable] = value - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs, - ) - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this End node - """ - outputs_config = [ - {"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs - ] - return Template.from_end_outputs(outputs_config) diff --git a/api/graphon/nodes/end/entities.py b/api/graphon/nodes/end/entities.py deleted file mode 100644 index 839aed7e4b..0000000000 --- a/api/graphon/nodes/end/entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import OutputVariableEntity - - -class EndNodeData(BaseNodeData): - """ - END Node Data. - """ - - type: NodeType = BuiltinNodeTypes.END - outputs: list[OutputVariableEntity] - - -class EndStreamParam(BaseModel): - """ - EndStreamParam entity - """ - - end_dependencies: dict[str, list[str]] = Field( - ..., description="end dependencies (end node id -> dependent node ids)" - ) - end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( - ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" - ) diff --git a/api/graphon/nodes/http_request/__init__.py b/api/graphon/nodes/http_request/__init__.py deleted file mode 100644 index b29099db23..0000000000 --- a/api/graphon/nodes/http_request/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - BodyData, - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeConfig, - HttpRequestNodeData, -) -from .node import HttpRequestNode - -__all__ = [ - "HTTP_REQUEST_CONFIG_FILTER_KEY", - "BodyData", - "HttpRequestNode", - "HttpRequestNodeAuthorization", - "HttpRequestNodeBody", - "HttpRequestNodeConfig", - "HttpRequestNodeData", - "build_http_request_config", - "resolve_http_request_config", -] diff --git a/api/graphon/nodes/http_request/config.py b/api/graphon/nodes/http_request/config.py deleted file mode 100644 index 53bf6c7ae4..0000000000 --- a/api/graphon/nodes/http_request/config.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Mapping - -from .entities import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNodeConfig - - -def build_http_request_config( - *, - max_connect_timeout: int = 10, - max_read_timeout: int = 600, - max_write_timeout: int = 600, - max_binary_size: int = 10 * 1024 * 1024, - max_text_size: int = 1 * 1024 * 1024, - ssl_verify: bool = True, - ssrf_default_max_retries: int = 3, -) -> HttpRequestNodeConfig: - return HttpRequestNodeConfig( - max_connect_timeout=max_connect_timeout, - max_read_timeout=max_read_timeout, - 
max_write_timeout=max_write_timeout, - max_binary_size=max_binary_size, - max_text_size=max_text_size, - ssl_verify=ssl_verify, - ssrf_default_max_retries=ssrf_default_max_retries, - ) - - -def resolve_http_request_config(filters: Mapping[str, object] | None) -> HttpRequestNodeConfig: - if not filters: - raise ValueError("http_request_config is required to build HTTP request default config") - config = filters.get(HTTP_REQUEST_CONFIG_FILTER_KEY) - if not isinstance(config, HttpRequestNodeConfig): - raise ValueError("http_request_config must be an HttpRequestNodeConfig instance") - return config diff --git a/api/graphon/nodes/http_request/entities.py b/api/graphon/nodes/http_request/entities.py deleted file mode 100644 index 6fa067bdd1..0000000000 --- a/api/graphon/nodes/http_request/entities.py +++ /dev/null @@ -1,241 +0,0 @@ -import mimetypes -from collections.abc import Sequence -from dataclasses import dataclass -from email.message import Message -from typing import Any, Literal - -import charset_normalizer -import httpx -from pydantic import BaseModel, Field, ValidationInfo, field_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - -HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" - - -class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal["basic", "bearer", "custom"] - api_key: str - header: str = "" - - -class HttpRequestNodeAuthorization(BaseModel): - type: Literal["no-auth", "api-key"] - config: HttpRequestNodeAuthorizationConfig | None = None - - @field_validator("config", mode="before") - @classmethod - def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): - """ - Check config, if type is no-auth, config should be None, otherwise it should be a dict. - """ - if values.data["type"] == "no-auth": - return None - else: - if not v or not isinstance(v, dict): - raise ValueError("config should be a dict") - - return v - - -class BodyData(BaseModel): - key: str = "" - type: Literal["file", "text"] - value: str = "" - file: Sequence[str] = Field(default_factory=list) - - -class HttpRequestNodeBody(BaseModel): - type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"] - data: Sequence[BodyData] = Field(default_factory=list) - - @field_validator("data", mode="before") - @classmethod - def check_data(cls, v: Any): - """For compatibility, if body is not set, return empty list.""" - if not v: - return [] - if isinstance(v, str): - return [BodyData(key="", type="text", value=v)] - return v - - -class HttpRequestNodeTimeout(BaseModel): - connect: int | None = None - read: int | None = None - write: int | None = None - - -@dataclass(frozen=True, slots=True) -class HttpRequestNodeConfig: - max_connect_timeout: int - max_read_timeout: int - max_write_timeout: int - max_binary_size: int - max_text_size: int - ssl_verify: bool - ssrf_default_max_retries: int - - def default_timeout(self) -> "HttpRequestNodeTimeout": - return HttpRequestNodeTimeout( - connect=self.max_connect_timeout, - read=self.max_read_timeout, - write=self.max_write_timeout, - ) - - -class HttpRequestNodeData(BaseNodeData): - """ - Code Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.HTTP_REQUEST - method: Literal[ - "get", - "post", - "put", - "patch", - "delete", - "head", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - authorization: HttpRequestNodeAuthorization - headers: str - params: str - body: HttpRequestNodeBody | None = None - timeout: HttpRequestNodeTimeout | None = None - ssl_verify: bool | None = None - - -class Response: - headers: dict[str, str] - response: httpx.Response - _cached_text: str | None - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) - self._cached_text = None - - @property - def is_file(self): - """ - Determine if the response contains a file by checking: - 1. Content-Disposition header (RFC 6266) - 2. Content characteristics - 3. MIME type analysis - """ - content_type = self.content_type.split(";")[0].strip().lower() - parsed_content_disposition = self.parsed_content_disposition - - # Check if it's explicitly marked as an attachment - if parsed_content_disposition: - disp_type = parsed_content_disposition.get_content_disposition() # Returns 'attachment', 'inline', or None - filename = parsed_content_disposition.get_filename() # Returns filename if present, None otherwise - if disp_type == "attachment" or filename is not None: - return True - - # For 'text/' types, only 'csv' should be downloaded as file - if content_type.startswith("text/") and "csv" not in content_type: - return False - - # For application types, try to detect if it's a text-based format - if content_type.startswith("application/"): - # Common text-based application types - if any( - text_type in content_type - for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql") - ): - return False - - # Try to detect if content is text-based by sampling first few bytes - try: - # Sample first 1024 bytes for text detection - content_sample = self.response.content[:1024] - content_sample.decode("utf-8") - # If we can decode as UTF-8 and find common text patterns, likely not a file - text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ") - if any(marker in content_sample for marker in text_markers): - return False - except UnicodeDecodeError: - # If we can't decode as UTF-8, likely a binary file - return True - - # For other types, use MIME type analysis - main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or "")) - if main_type: - return main_type.split("/")[0] in ("application", "image", "audio", "video") - - # For unknown types, check if it's a media type - return any(media_type in content_type for media_type in ("image/", "audio/", "video/")) - - @property - def content_type(self) -> str: - return self.headers.get("content-type", "") - - @property - def text(self) -> str: - """ - Get response text with robust encoding detection. - - Uses charset_normalizer for better encoding detection than httpx's default, - which helps handle Chinese and other non-ASCII characters properly. 
- """ - # Check cache first - if hasattr(self, "_cached_text") and self._cached_text is not None: - return self._cached_text - - # Try charset_normalizer for robust encoding detection first - detected_encoding = charset_normalizer.from_bytes(self.response.content).best() - if detected_encoding and detected_encoding.encoding: - try: - text = self.response.content.decode(detected_encoding.encoding) - self._cached_text = text - return text - except (UnicodeDecodeError, TypeError, LookupError): - # Fallback to httpx's encoding detection if charset_normalizer fails - pass - - # Fallback to httpx's built-in encoding detection - text = self.response.text - self._cached_text = text - return text - - @property - def content(self) -> bytes: - return self.response.content - - @property - def status_code(self) -> int: - return self.response.status_code - - @property - def size(self) -> int: - return len(self.content) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f"{self.size} bytes" - elif self.size < 1024 * 1024: - return f"{(self.size / 1024):.2f} KB" - else: - return f"{(self.size / 1024 / 1024):.2f} MB" - - @property - def parsed_content_disposition(self) -> Message | None: - content_disposition = self.headers.get("content-disposition", "") - if content_disposition: - msg = Message() - msg["content-disposition"] = content_disposition - return msg - return None diff --git a/api/graphon/nodes/http_request/exc.py b/api/graphon/nodes/http_request/exc.py deleted file mode 100644 index 46613c9e86..0000000000 --- a/api/graphon/nodes/http_request/exc.py +++ /dev/null @@ -1,26 +0,0 @@ -class HttpRequestNodeError(ValueError): - """Custom error for HTTP request node.""" - - -class AuthorizationConfigError(HttpRequestNodeError): - """Raised when authorization config is missing or invalid.""" - - -class FileFetchError(HttpRequestNodeError): - """Raised when a file cannot be fetched.""" - - -class InvalidHttpMethodError(HttpRequestNodeError): - """Raised when an invalid HTTP method is used.""" - - -class ResponseSizeError(HttpRequestNodeError): - """Raised when the response size exceeds the allowed threshold.""" - - -class RequestBodyError(HttpRequestNodeError): - """Raised when the request body is invalid.""" - - -class InvalidURLError(HttpRequestNodeError): - """Raised when the URL is invalid.""" diff --git a/api/graphon/nodes/http_request/executor.py b/api/graphon/nodes/http_request/executor.py deleted file mode 100644 index 0c6f4ecd3a..0000000000 --- a/api/graphon/nodes/http_request/executor.py +++ /dev/null @@ -1,488 +0,0 @@ -import base64 -import json -import secrets -import string -from collections.abc import Callable, Mapping -from copy import deepcopy -from typing import Any, Literal -from urllib.parse import urlencode, urlparse - -import httpx -from json_repair import repair_json - -from graphon.file.enums import FileTransferMethod -from graphon.runtime import VariablePool -from graphon.variables.segments import ArrayFileSegment, FileSegment - -from ..protocols import FileManagerProtocol, HttpClientProtocol -from .entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import ( - AuthorizationConfigError, - FileFetchError, - HttpRequestNodeError, - InvalidHttpMethodError, - InvalidURLError, - RequestBodyError, - ResponseSizeError, -) - -BODY_TYPE_TO_CONTENT_TYPE = { - "json": "application/json", - "x-www-form-urlencoded": "application/x-www-form-urlencoded", - "form-data": 
"multipart/form-data", - "raw-text": "text/plain", -} - - -class Executor: - method: Literal[ - "get", - "head", - "post", - "put", - "delete", - "patch", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - params: list[tuple[str, str]] | None - content: str | bytes | None - data: Mapping[str, Any] | None - files: list[tuple[str, tuple[str | None, bytes, str]]] | None - json: Any - headers: dict[str, str] - auth: HttpRequestNodeAuthorization - timeout: HttpRequestNodeTimeout - max_retries: int - - boundary: str - - def __init__( - self, - *, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: VariablePool, - http_request_config: HttpRequestNodeConfig, - max_retries: int | None = None, - ssl_verify: bool | None = None, - http_client: HttpClientProtocol, - file_manager: FileManagerProtocol, - ): - self._http_request_config = http_request_config - # If authorization API key is present, convert the API key using the variable pool - if node_data.authorization.type == "api-key": - if node_data.authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - node_data.authorization.config.api_key = variable_pool.convert_template( - node_data.authorization.config.api_key - ).text - # Validate that API key is not empty after template conversion - if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip(): - raise AuthorizationConfigError( - "API key is required for authorization but was empty. Please provide a valid API key." - ) - - self.url = node_data.url - self.method = node_data.method - self.auth = node_data.authorization - self.timeout = timeout - self.ssl_verify = ssl_verify if ssl_verify is not None else node_data.ssl_verify - if self.ssl_verify is None: - self.ssl_verify = self._http_request_config.ssl_verify - if not isinstance(self.ssl_verify, bool): - raise ValueError("ssl_verify must be a boolean") - self.params = None - self.headers = {} - self.content = None - self.files = None - self.data = None - self.json = None - self.max_retries = ( - max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries - ) - self._http_client = http_client - self._file_manager = file_manager - - # init template - self.variable_pool = variable_pool - self.node_data = node_data - self._initialize() - - def _initialize(self): - self._init_url() - self._init_params() - self._init_headers() - self._init_body() - - def _init_url(self): - self.url = self.variable_pool.convert_template(self.node_data.url).text - - # check if url is a valid URL - if not self.url: - raise InvalidURLError("url is required") - if not self.url.startswith(("http://", "https://")): - raise InvalidURLError("url should start with http:// or https://") - - def _init_params(self): - """ - Almost same as _init_headers(), difference: - 1. response a list tuple to support same key, like 'aa=1&aa=2' - 2. param value may have '\n', we need to splitlines then extract the variable value. 
- """ - result = [] - for line in self.node_data.params.splitlines(): - if not (line := line.strip()): - continue - - key, *value = line.split(":", 1) - if not (key := key.strip()): - continue - - value_str = value[0].strip() if value else "" - result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) - ) - - if result: - self.params = result - - def _init_headers(self): - """ - Convert the header string of frontend to a dictionary. - - Each line in the header string represents a key-value pair. - Keys and values are separated by ':'. - Empty values are allowed. - - Examples: - 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} - 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} - 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} - - """ - headers = self.variable_pool.convert_template(self.node_data.headers).text - self.headers = { - key.strip(): (value[0].strip() if value else "") - for line in headers.splitlines() - if line.strip() - for key, *value in [line.split(":", 1)] - } - - def _init_body(self): - body = self.node_data.body - if body is not None: - data = body.data - match body.type: - case "none": - self.content = "" - case "raw-text": - if len(data) != 1: - raise RequestBodyError("raw-text body type should have exactly one item") - self.content = self.variable_pool.convert_template(data[0].value).text - case "json": - if len(data) != 1: - raise RequestBodyError("json body type should have exactly one item") - json_string = self.variable_pool.convert_template(data[0].value).text - try: - repaired = repair_json(json_string) - json_object = json.loads(repaired, strict=False) - except json.JSONDecodeError as e: - raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e - self.json = json_object - # self.json = self._parse_object_contains_variables(json_object) - case "binary": - if len(data) != 1: - raise RequestBodyError("binary body type should have exactly one item") - file_selector = data[0].file - file_variable = self.variable_pool.get_file(file_selector) - if file_variable is None: - raise FileFetchError(f"cannot fetch file with selector {file_selector}") - file = file_variable.value - self.content = self._file_manager.download(file) - case "x-www-form-urlencoded": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in data - } - self.data = form_data - case "form-data": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in filter(lambda item: item.type == "text", data) - } - file_selectors = { - self.variable_pool.convert_template(item.key).text: item.file - for item in filter(lambda item: item.type == "file", data) - } - - # get files from file_selectors, add support for array file variables - files_list = [] - for key, selector in file_selectors.items(): - segment = self.variable_pool.get(selector) - if isinstance(segment, FileSegment): - files_list.append((key, [segment.value])) - elif isinstance(segment, ArrayFileSegment): - files_list.append((key, list(segment.value))) - - # get files from file_manager - files: dict[str, list[tuple[str | None, bytes, str]]] = {} - for key, files_in_segment in files_list: - for file in files_in_segment: - if file.reference is not None or ( - file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None - ): - file_tuple = ( - file.filename, - self._file_manager.download(file), - file.mime_type 
or "application/octet-stream", - ) - if key not in files: - files[key] = [] - files[key].append(file_tuple) - - # convert files to list for httpx request - # If there are no actual files, we still need to force httpx to use `multipart/form-data`. - # This is achieved by inserting a harmless placeholder file that will be ignored by the server. - if not files: - self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))] - if files: - self.files = [] - for key, file_tuples in files.items(): - for file_tuple in file_tuples: - self.files.append((key, file_tuple)) - - self.data = form_data - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.auth) - headers = deepcopy(self.headers) or {} - if self.auth.type == "api-key": - if self.auth.config is None: - raise AuthorizationConfigError("self.authorization config is required") - if authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - - if not authorization.config.header: - authorization.config.header = "Authorization" - - if self.auth.config.type == "bearer" and authorization.config.api_key: - headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.auth.config.type == "basic" and authorization.config.api_key: - credentials = authorization.config.api_key - if ":" in credentials: - encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") - else: - encoded_credentials = credentials - headers[authorization.config.header] = f"Basic {encoded_credentials}" - elif self.auth.config.type == "custom": - if authorization.config.header and authorization.config.api_key: - headers[authorization.config.header] = authorization.config.api_key - - # Handle Content-Type for multipart/form-data requests - # Fix for issue #23829: Missing boundary when using multipart/form-data - body = self.node_data.body - if body and body.type == "form-data": - # For multipart/form-data with files (including placeholder files), - # remove any manually set Content-Type header to let httpx handle - # For multipart/form-data, if any files are present (including placeholder files), - # we must remove any manually set Content-Type header. This is because httpx needs to - # automatically set the Content-Type and boundary for multipart encoding whenever files - # are included, even if they are placeholders, to avoid boundary issues and ensure correct - # file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the - # boundary, resulting in invalid requests. 
- if self.files: - # Remove Content-Type if it was manually set to avoid boundary issues - headers = {k: v for k, v in headers.items() if k.lower() != "content-type"} - else: - # No files at all, set Content-Type manually - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = "multipart/form-data" - elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: - # Set Content-Type for other body types - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> Response: - executor_response = Response(response) - - threshold_size = ( - self._http_request_config.max_binary_size - if executor_response.is_file - else self._http_request_config.max_text_size - ) - if executor_response.size > threshold_size: - raise ResponseSizeError( - f"{'File' if executor_response.is_file else 'Text'} size is too large," - f" max size is {threshold_size / 1024 / 1024:.2f} MB," - f" but current size is {executor_response.readable_size}." - ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = { - "get": self._http_client.get, - "head": self._http_client.head, - "post": self._http_client.post, - "put": self._http_client.put, - "delete": self._http_client.delete, - "patch": self._http_client.patch, - } - method_lc = self.method.lower() - if method_lc not in _METHOD_MAP: - raise InvalidHttpMethodError(f"Invalid http method {self.method}") - - request_args: dict[str, Any] = { - "data": self.data, - "files": self.files, - "json": self.json, - "content": self.content, - "headers": headers, - "params": self.params, - "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), - "ssl_verify": self.ssl_verify, - "follow_redirects": True, - } - # request_args = {k: v for k, v in request_args.items() if v is not None} - try: - response = _METHOD_MAP[method_lc]( - url=self.url, - **request_args, - max_retries=self.max_retries, - ) - except self._http_client.max_retries_exceeded_error as e: - raise HttpRequestNodeError(f"Reached maximum retries for URL {self.url}") from e - except self._http_client.request_error as e: - raise HttpRequestNodeError(str(e)) from e - return response - - def invoke(self) -> Response: - # assemble headers - headers = self._assembling_headers() - # do http request - response = self._do_http_request(headers) - # validate response - return self._validate_and_parse_response(response) - - def to_log(self): - url_parts = urlparse(self.url) - path = url_parts.path or "/" - - # Add query parameters - if self.params: - query_string = urlencode(self.params) - path += f"?{query_string}" - elif url_parts.query: - path += f"?{url_parts.query}" - - raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" - raw += f"Host: {url_parts.netloc}\r\n" - - headers = self._assembling_headers() - body = self.node_data.body - boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" - if body: - if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - if body.type == "form-data": - headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" - for k, v in headers.items(): - if self.auth.type == "api-key": - authorization_header = "Authorization" - if 
self.auth.config and self.auth.config.header:
- authorization_header = self.auth.config.header
- if k.lower() == authorization_header.lower():
- raw += f"{k}: {'*' * len(v)}\r\n"
- continue
- raw += f"{k}: {v}\r\n"
-
- body_string = ""
- # Only log actual files if present.
- # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
- # This prevents logging meaningless placeholder entries.
- if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
- for file_entry in self.files:
- # file_entry should be (key, (filename, content, mime_type)), but handle edge cases
- if len(file_entry) != 2 or len(file_entry[1]) < 2:
- continue # skip malformed entries
- key = file_entry[0]
- content = file_entry[1][1]
- body_string += f"--{boundary}\r\n"
- body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
- # decode content safely
- # Do not decode binary content; use a placeholder with file metadata instead.
- # Includes filename, size, and MIME type for better logging context.
- body_string += (
- f"<binary data: filename={file_entry[1][0]}, size={len(content)} bytes, "
- f"mime_type={file_entry[1][2] if len(file_entry[1]) > 2 else 'unknown'}>\r\n"
- )
- body_string += f"--{boundary}--\r\n"
- elif self.node_data.body:
- if self.content:
- # If content is bytes, do not decode it; show a placeholder with size.
- # Provides content size information for binary data without exposing the raw bytes.
- if isinstance(self.content, bytes):
- body_string = f"<binary data: {len(self.content)} bytes>"
- else:
- body_string = self.content
- elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
- body_string = urlencode(self.data)
- elif self.data and self.node_data.body.type == "form-data":
- for key, value in self.data.items():
- body_string += f"--{boundary}\r\n"
- body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
- body_string += f"{value}\r\n"
- body_string += f"--{boundary}--\r\n"
- elif self.json:
- body_string = json.dumps(self.json)
- elif self.node_data.body.type == "raw-text":
- if len(self.node_data.body.data) != 1:
- raise RequestBodyError("raw-text body type should have exactly one item")
- body_string = self.node_data.body.data[0].value
- if body_string:
- raw += f"Content-Length: {len(body_string)}\r\n"
- raw += "\r\n" # Empty line between headers and body
- raw += body_string
-
- return raw
-
-
-def _generate_random_string(n: int) -> str:
- """
- Generate a random string of lowercase ASCII letters.
-
- Args:
- n (int): The length of the random string to generate.
-
- Returns:
- str: A random string of lowercase ASCII letters with length n.
- - Example: - >>> _generate_random_string(5) - 'abcde' - """ - return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n)) diff --git a/api/graphon/nodes/http_request/node.py b/api/graphon/nodes/http_request/node.py deleted file mode 100644 index 3d74347a7f..0000000000 --- a/api/graphon/nodes/http_request/node.py +++ /dev/null @@ -1,261 +0,0 @@ -import logging -import mimetypes -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod -from graphon.node_events import NodeRunResult -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.http_request.executor import Executor -from graphon.nodes.protocols import ( - FileManagerProtocol, - FileReferenceFactoryProtocol, - HttpClientProtocol, - ToolFileManagerProtocol, -) -from graphon.variables.segments import ArrayFileSegment - -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import HttpRequestNodeError, RequestBodyError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -class HttpRequestNode(Node[HttpRequestNodeData]): - node_type = BuiltinNodeTypes.HTTP_REQUEST - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - http_request_config: HttpRequestNodeConfig, - http_client: HttpClientProtocol, - tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], - file_manager: FileManagerProtocol, - file_reference_factory: FileReferenceFactoryProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - self._http_request_config = http_request_config - self._http_client = http_client - self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager - self._file_reference_factory = file_reference_factory - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - if not filters or HTTP_REQUEST_CONFIG_FILTER_KEY not in filters: - http_request_config = build_http_request_config() - else: - http_request_config = resolve_http_request_config(filters) - default_timeout = http_request_config.default_timeout() - return { - "type": "http-request", - "config": { - "method": "get", - "authorization": { - "type": "no-auth", - }, - "body": {"type": "none"}, - "timeout": { - **default_timeout.model_dump(), - "max_connect_timeout": http_request_config.max_connect_timeout, - "max_read_timeout": http_request_config.max_read_timeout, - "max_write_timeout": http_request_config.max_write_timeout, - }, - "ssl_verify": http_request_config.ssl_verify, - }, - "retry_config": { - "max_retries": http_request_config.ssrf_default_max_retries, - "retry_interval": 0.5 * (2**2), - "retry_enabled": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - process_data = {} - try: - http_executor = 
Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), - variable_pool=self.graph_runtime_state.variable_pool, - http_request_config=self._http_request_config, - # Must be 0 to disable executor-level retries, as the graph engine handles them. - # This is critical to prevent nested retries. - max_retries=0, - ssl_verify=self.node_data.ssl_verify, - http_client=self._http_client, - file_manager=self._file_manager, - ) - process_data["request"] = http_executor.to_log() - - response = http_executor.invoke() - files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and (self.error_strategy or self.retry): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - error=f"Request failed with status code {response.status_code}", - error_type="HTTPResponseCodeError", - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - ) - except HttpRequestNodeError as e: - logger.warning("http request node %s failed to run: %s", self._node_id, e) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - error_type=type(e).__name__, - ) - - def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: - default_timeout = self._http_request_config.default_timeout() - timeout = node_data.timeout - if timeout is None: - return default_timeout - - return HttpRequestNodeTimeout( - connect=timeout.connect or default_timeout.connect, - read=timeout.read or default_timeout.read, - write=timeout.write or default_timeout.write, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HttpRequestNodeData, - ) -> Mapping[str, Sequence[str]]: - selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data - match body_type: - case "none": - pass - case "binary": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selector = data[0].file - selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) - case "json" | "raw-text": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selectors += variable_template_parser.extract_selectors_from_template(data[0].key) - selectors += variable_template_parser.extract_selectors_from_template(data[0].value) - case "x-www-form-urlencoded": - for item in data: - selectors += variable_template_parser.extract_selectors_from_template(item.key) - selectors += variable_template_parser.extract_selectors_from_template(item.value) - case "form-data": - for item in data: - selectors += 
variable_template_parser.extract_selectors_from_template(item.key) - if item.type == "text": - selectors += variable_template_parser.extract_selectors_from_template(item.value) - elif item.type == "file": - selectors.append( - VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) - ) - - mapping = {} - for selector_iter in selectors: - mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector - - return mapping - - def extract_files(self, url: str, response: Response) -> ArrayFileSegment: - """ - Extract files from response by checking both Content-Type header and URL - """ - files: list[File] = [] - is_file = response.is_file - content_type = response.content_type - content = response.content - parsed_content_disposition = response.parsed_content_disposition - content_disposition_type = None - - if not is_file: - return ArrayFileSegment(value=[]) - - if parsed_content_disposition: - content_disposition_filename = parsed_content_disposition.get_filename() - if content_disposition_filename: - # If filename is available from content-disposition, use it to guess the content type - content_disposition_type = mimetypes.guess_type(content_disposition_filename)[0] - - # Guess file extension from URL or Content-Type header - filename = url.split("?")[0].split("/")[-1] or "" - mime_type = ( - content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - tool_file_manager = self._tool_file_manager_factory() - - tool_file = tool_file_manager.create_file_by_raw( - file_binary=content, - mimetype=mime_type, - ) - - file = self._file_reference_factory.build_from_mapping( - mapping={ - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - ) - files.append(file) - - return ArrayFileSegment(value=files) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/graphon/nodes/human_input/__init__.py b/api/graphon/nodes/human_input/__init__.py deleted file mode 100644 index 1789604577..0000000000 --- a/api/graphon/nodes/human_input/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Human Input node implementation. -""" diff --git a/api/graphon/nodes/human_input/entities.py b/api/graphon/nodes/human_input/entities.py deleted file mode 100644 index aa01bde145..0000000000 --- a/api/graphon/nodes/human_input/entities.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Human Input node entities. - -The graph package owns the workflow-facing form schema and keeps it transportable -across runtimes. Dify-specific delivery surface and recipient translation stay -outside `graphon`. -""" - -import re -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Any, Self - -from pydantic import BaseModel, Field, field_validator, model_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.variables.consts import SELECTORS_LENGTH - -from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. 
However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? - value: str = "" - - @model_validator(mode="after") - def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: - return self - if len(self.selector) < SELECTORS_LENGTH: - raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") - return self - - -class FormInput(BaseModel): - """Form input definition.""" - - type: FormInputType - output_variable_name: str - default: FormInputDefault | None = None - - -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -class UserAction(BaseModel): - """User action configuration.""" - - # id is the identifier for this action. - # It also serves as the identifiers of output handle. - # - # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) - id: str = Field(max_length=20) - title: str = Field(max_length=20) - button_style: ButtonStyle = ButtonStyle.DEFAULT - - @field_validator("id") - @classmethod - def _validate_id(cls, value: str) -> str: - if not _IDENTIFIER_PATTERN.match(value): - raise ValueError( - f"'{value}' is not a valid identifier. It must start with a letter or underscore, " - f"and contain only letters, numbers, or underscores." - ) - return value - - -class HumanInputNodeData(BaseNodeData): - """Human Input node data.""" - - type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - timeout: int = 36 - timeout_unit: TimeoutUnit = TimeoutUnit.HOUR - - @field_validator("inputs") - @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: - seen_names: set[str] = set() - for form_input in inputs: - name = form_input.output_variable_name - if name in seen_names: - raise ValueError(f"duplicated output_variable_name '{name}' in inputs") - seen_names.add(name) - return inputs - - @field_validator("user_actions") - @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: - seen_ids: set[str] = set() - for action in user_actions: - action_id = action.id - if action_id in seen_ids: - raise ValueError(f"duplicated user action id '{action_id}'") - seen_ids.add(action_id) - return user_actions - - def expiration_time(self, start_time: datetime) -> datetime: - if self.timeout_unit == TimeoutUnit.HOUR: - return start_time + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - return start_time + timedelta(days=self.timeout) - else: - raise AssertionError("unknown timeout unit.") - - def outputs_field_names(self) -> Sequence[str]: - field_names = [] - for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): - field_names.append(match.group("field_name")) - return field_names - - def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: - variable_mappings: dict[str, Sequence[str]] = {} - - def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: - for selector in selectors: - if len(selector) < 
SELECTORS_LENGTH: - continue - qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" - variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) - - form_template_parser = VariableTemplateParser(template=self.form_content) - _add_variable_selectors( - [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] - ) - - for input in self.inputs: - default_value = input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector - - return variable_mappings - - def find_action_text(self, action_id: str) -> str: - """ - Resolve action display text by id. - """ - for action in self.user_actions: - if action.id == action_id: - return action.title - return action_id - - -class FormDefinition(BaseModel): - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - rendered_content: str - expiration_time: datetime - - # this is used to store the resolved default values - default_values: dict[str, Any] = Field(default_factory=dict) - - # node_title records the title of the HumanInput node. - node_title: str | None = None - - # display_in_ui controls whether the form should be displayed in UI surfaces. - display_in_ui: bool | None = None - - -class HumanInputSubmissionValidationError(ValueError): - pass - - -def validate_human_input_submission( - *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], - selected_action_id: str, - form_data: Mapping[str, Any], -) -> None: - available_actions = {action.id for action in user_actions} - if selected_action_id not in available_actions: - raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - missing_list = ", ".join(missing_inputs) - raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/graphon/nodes/human_input/enums.py b/api/graphon/nodes/human_input/enums.py deleted file mode 100644 index 3fb0ab4499..0000000000 --- a/api/graphon/nodes/human_input/enums.py +++ /dev/null @@ -1,55 +0,0 @@ -import enum - - -class HumanInputFormStatus(enum.StrEnum): - """Status of a human input form.""" - - # Awaiting submission from any recipient. Forms stay in this state until - # submitted or a timeout rule applies. - WAITING = enum.auto() - # Global timeout reached. The workflow run is stopped and will not resume. - # This is distinct from node-level timeout. - EXPIRED = enum.auto() - # Submitted by a recipient; form data is available and execution resumes - # along the selected action edge. - SUBMITTED = enum.auto() - # Node-level timeout reached. The human input node should emit a timeout - # event and the workflow should resume along the timeout edge. - TIMEOUT = enum.auto() - - -class HumanInputFormKind(enum.StrEnum): - """Kind of a human input form.""" - - RUNTIME = enum.auto() # Form created during workflow execution. - DELIVERY_TEST = enum.auto() # Form created for delivery tests. 
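- # Note: with enum.StrEnum, enum.auto() resolves to the lower-cased member
- # name, e.g. HumanInputFormKind.DELIVERY_TEST == "delivery_test" and, below,
- # ButtonStyle.PRIMARY == "primary".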
- - -class ButtonStyle(enum.StrEnum): - """Button styles for user actions.""" - - PRIMARY = enum.auto() - DEFAULT = enum.auto() - ACCENT = enum.auto() - GHOST = enum.auto() - - -class TimeoutUnit(enum.StrEnum): - """Timeout unit for form expiration.""" - - HOUR = enum.auto() - DAY = enum.auto() - - -class FormInputType(enum.StrEnum): - """Form input types.""" - - TEXT_INPUT = enum.auto() - PARAGRAPH = enum.auto() - - -class PlaceholderType(enum.StrEnum): - """Default value types for form inputs.""" - - VARIABLE = enum.auto() - CONSTANT = enum.auto() diff --git a/api/graphon/nodes/human_input/human_input_node.py b/api/graphon/nodes/human_input/human_input_node.py deleted file mode 100644 index fe04022877..0000000000 --- a/api/graphon/nodes/human_input/human_input_node.py +++ /dev/null @@ -1,299 +0,0 @@ -import json -import logging -from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - NodeRunResult, - PauseRequestedEvent, -) -from graphon.node_events.base import NodeEventBase -from graphon.node_events.node import StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.runtime import HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter - -from .entities import HumanInputNodeData -from .enums import HumanInputFormStatus, PlaceholderType - -if TYPE_CHECKING: - from graphon.entities.graph_init_params import GraphInitParams - from graphon.runtime.graph_runtime_state import GraphRuntimeState - - -_SELECTED_BRANCH_KEY = "selected_branch" - - -logger = logging.getLogger(__name__) - - -class HumanInputNode(Node[HumanInputNodeData]): - node_type = BuiltinNodeTypes.HUMAN_INPUT - execution_type = NodeExecutionType.BRANCH - - _BRANCH_SELECTION_KEYS: tuple[str, ...] 
= ( - "edge_source_handle", - "edgeSourceHandle", - "source_handle", - _SELECTED_BRANCH_KEY, - "selectedBranch", - "branch", - "branch_id", - "branchId", - "handle", - ) - - _node_data: HumanInputNodeData - _OUTPUT_FIELD_ACTION_ID = "__action_id" - _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" - _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - runtime: HumanInputNodeRuntimeProtocol | None = None, - form_repository: object | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - resolved_runtime = runtime - if resolved_runtime is None: - raise ValueError("runtime is required") - if form_repository is not None: - with_form_repository = getattr(resolved_runtime, "with_form_repository", None) - if callable(with_form_repository): - resolved_runtime = cast(HumanInputNodeRuntimeProtocol, with_form_repository(form_repository)) - self._runtime: HumanInputNodeRuntimeProtocol = resolved_runtime - - @classmethod - def version(cls) -> str: - return "1" - - def _resolve_branch_selection(self) -> str | None: - """Determine the branch handle selected by human input if available.""" - - variable_pool = self.graph_runtime_state.variable_pool - - for key in self._BRANCH_SELECTION_KEYS: - handle = self._extract_branch_handle(variable_pool.get((self.id, key))) - if handle: - return handle - - default_values = self.node_data.default_value_dict - for key in self._BRANCH_SELECTION_KEYS: - handle = self._normalize_branch_value(default_values.get(key)) - if handle: - return handle - - return None - - @staticmethod - def _extract_branch_handle(segment: Any) -> str | None: - if segment is None: - return None - - candidate = getattr(segment, "to_object", None) - raw_value = candidate() if callable(candidate) else getattr(segment, "value", None) - if raw_value is None: - return None - - return HumanInputNode._normalize_branch_value(raw_value) - - @staticmethod - def _normalize_branch_value(value: Any) -> str | None: - if value is None: - return None - - if isinstance(value, str): - stripped = value.strip() - return stripped or None - - if isinstance(value, Mapping): - for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"): - candidate = value.get(key) - if isinstance(candidate, str) and candidate: - return candidate - - return None - - def _form_to_pause_event(self, form_entity: HumanInputFormStateProtocol): - required_event = self._human_input_required_event(form_entity) - pause_requested_event = PauseRequestedEvent(reason=required_event) - return pause_requested_event - - def resolve_default_values(self) -> Mapping[str, Any]: - variable_pool = self.graph_runtime_state.variable_pool - resolved_defaults: dict[str, Any] = {} - for input in self._node_data.inputs: - if (default_value := input.default) is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - resolved_value = variable_pool.get(default_value.selector) - if resolved_value is None: - # TODO: How should we handle this? 
- continue - resolved_defaults[input.output_variable_name] = ( - WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value) - ) - - return resolved_defaults - - def _human_input_required_event(self, form_entity: HumanInputFormStateProtocol) -> HumanInputRequired: - node_data = self._node_data - resolved_default_values = self.resolve_default_values() - return HumanInputRequired( - form_id=form_entity.id, - form_content=form_entity.rendered_content, - inputs=node_data.inputs, - actions=node_data.user_actions, - node_id=self.id, - node_title=node_data.title, - resolved_default_values=resolved_default_values, - ) - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Execute the human input node. - - This method will: - 1. Generate a unique form ID - 2. Create form content with variable substitution - 3. Persist the form through the configured repository - 4. Send form via configured delivery methods - 5. Suspend workflow execution - 6. Wait for form submission to resume - """ - form = self._runtime.get_form(node_id=self.id) - if form is None: - form_entity = self._runtime.create_form( - node_id=self.id, - node_data=self._node_data, - rendered_content=self.render_form_content_before_submission(), - resolved_default_values=self.resolve_default_values(), - ) - - logger.info( - "Human Input node suspended workflow for form. node_id=%s, form_id=%s", - self.id, - form_entity.id, - ) - yield self._form_to_pause_event(form_entity) - return - - if form.status in { - HumanInputFormStatus.TIMEOUT, - HumanInputFormStatus.EXPIRED, - } or form.expiration_time <= datetime.now(UTC).replace(tzinfo=None): - yield HumanInputFormTimeoutEvent( - node_title=self._node_data.title, - expiration_time=form.expiration_time, - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={self._OUTPUT_FIELD_ACTION_ID: ""}, - edge_source_handle=self._TIMEOUT_HANDLE, - ) - ) - return - - if not form.submitted: - yield self._form_to_pause_event(form) - return - - selected_action_id = form.selected_action_id - if selected_action_id is None: - raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") - submitted_data = form.submitted_data or {} - outputs: dict[str, Any] = dict(submitted_data) - outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id - rendered_content = self.render_form_content_with_outputs( - form.rendered_content, - outputs, - self._node_data.outputs_field_names(), - ) - outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content - - action_text = self._node_data.find_action_text(selected_action_id) - - yield HumanInputFormFilledEvent( - node_title=self._node_data.title, - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - edge_source_handle=selected_action_id, - ) - ) - - def render_form_content_before_submission(self) -> str: - """ - Process form content by substituting variables. - - This method should: - 1. Parse the form_content markdown - 2. Substitute {{#node_name.var_name#}} with actual values - 3. 
Keep {{#$output.field_name#}} placeholders for form inputs - """ - rendered_form_content = self.graph_runtime_state.variable_pool.convert_template( - self._node_data.form_content, - ) - return rendered_form_content.markdown - - @staticmethod - def render_form_content_with_outputs( - form_content: str, - outputs: Mapping[str, Any], - field_names: Sequence[str], - ) -> str: - """ - Replace {{#$output.xxx#}} placeholders with submitted values. - """ - rendered_content = form_content - for field_name in field_names: - placeholder = "{{#$output." + field_name + "#}}" - value = outputs.get(field_name) - if value is None: - replacement = "" - elif isinstance(value, (dict, list)): - replacement = json.dumps(value, ensure_ascii=False) - else: - replacement = str(value) - rendered_content = rendered_content.replace(placeholder, replacement) - return rendered_content - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HumanInputNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selectors referenced in form content and input default values. - - This method should parse: - 1. Variables referenced in form_content ({{#node_name.var_name#}}) - 2. Variables referenced in input default values - """ - return node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/graphon/nodes/if_else/__init__.py b/api/graphon/nodes/if_else/__init__.py deleted file mode 100644 index afa0e8112c..0000000000 --- a/api/graphon/nodes/if_else/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .if_else_node import IfElseNode - -__all__ = ["IfElseNode"] diff --git a/api/graphon/nodes/if_else/entities.py b/api/graphon/nodes/if_else/entities.py deleted file mode 100644 index d59b782747..0000000000 --- a/api/graphon/nodes/if_else/entities.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.utils.condition.entities import Condition - - -class IfElseNodeData(BaseNodeData): - """ - If Else Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.IF_ELSE - - class Case(BaseModel): - """ - Case entity representing a single logical condition group - """ - - case_id: str - logical_operator: Literal["and", "or"] - conditions: list[Condition] - - logical_operator: Literal["and", "or"] | None = "and" - conditions: list[Condition] | None = Field(default=None, deprecated=True) - - cases: list[Case] | None = None diff --git a/api/graphon/nodes/if_else/if_else_node.py b/api/graphon/nodes/if_else/if_else_node.py deleted file mode 100644 index 81e934971a..0000000000 --- a/api/graphon/nodes/if_else/if_else_node.py +++ /dev/null @@ -1,124 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from typing_extensions import deprecated - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.if_else.entities import IfElseNodeData -from graphon.runtime import VariablePool -from graphon.utils.condition.entities import Condition -from graphon.utils.condition.processor import ConditionProcessor - - -class IfElseNode(Node[IfElseNodeData]): - node_type = BuiltinNodeTypes.IF_ELSE - execution_type = NodeExecutionType.BRANCH - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - :return: - """ - node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []} - - process_data: dict[str, list] = {"condition_results": []} - - input_conditions: Sequence[Mapping[str, Any]] = [] - final_result = False - selected_case_id = "false" - condition_processor = ConditionProcessor() - try: - # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: - input_conditions, group_result, final_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=case.conditions, - operator=case.logical_operator, - ) - - process_data["condition_results"].append( - { - "group": case.model_dump(), - "results": group_result, - "final_result": final_result, - } - ) - - # Break if a case passes (logical short-circuit) - if final_result: - selected_case_id = case.case_id # Capture the ID of the passing case - break - - else: - # TODO: Remove this once all graph definitions use the `cases` structure. - # Fallback to the legacy node shape when `cases` are not defined. 
- input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] - condition_processor=condition_processor, - variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", - ) - - selected_case_id = "true" if final_result else "false" - - process_data["condition_results"].append( - {"group": "default", "results": group_result, "final_result": final_result} - ) - - node_inputs["conditions"] = input_conditions - - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e) - ) - - outputs = {"result": final_result, "selected_case_id": selected_case_id} - - data = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - edge_source_handle=selected_case_id or "false", # Use the selected case ID, falling back to "false" - outputs=outputs, - ) - - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IfElseNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, list[str]] = {} - _ = graph_config # Explicitly mark as unused - for case in node_data.cases or []: - for condition in case.conditions: - key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" - var_mapping[key] = condition.variable_selector - - return var_mapping - - -@deprecated("This function is deprecated. You should use the new cases structure.") -def _should_not_use_old_function( - *, - condition_processor: ConditionProcessor, - variable_pool: VariablePool, - conditions: list[Condition], - operator: Literal["and", "or"], -): - return condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=conditions, - operator=operator, - ) diff --git a/api/graphon/nodes/iteration/__init__.py b/api/graphon/nodes/iteration/__init__.py deleted file mode 100644 index 5bb87aaffa..0000000000 --- a/api/graphon/nodes/iteration/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .entities import IterationNodeData -from .iteration_node import IterationNode -from .iteration_start_node import IterationStartNode - -__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"] diff --git a/api/graphon/nodes/iteration/entities.py b/api/graphon/nodes/iteration/entities.py deleted file mode 100644 index 30b6e4bea8..0000000000 --- a/api/graphon/nodes/iteration/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from enum import StrEnum -from typing import Any - -from pydantic import Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base import BaseIterationNodeData, BaseIterationState - - -class ErrorHandleMode(StrEnum): - TERMINATED = "terminated" - CONTINUE_ON_ERROR = "continue-on-error" - REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" - - -class IterationNodeData(BaseIterationNodeData): - """ - Iteration Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.ITERATION - parent_loop_id: str | None = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector - is_parallel: bool = False # open the parallel mode or not - parallel_nums: int = 10 # the numbers of parallel - error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error - flatten_output: bool = True # whether to flatten the output array if all elements are lists - - -class IterationStartNodeData(BaseNodeData): - """ - Iteration Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ITERATION_START - - -class IterationState(BaseIterationState): - """ - Iteration State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseIterationState.MetaData): - """ - Data. - """ - - iterator_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. - """ - return self.current_output diff --git a/api/graphon/nodes/iteration/exc.py b/api/graphon/nodes/iteration/exc.py deleted file mode 100644 index 7b6af61b9d..0000000000 --- a/api/graphon/nodes/iteration/exc.py +++ /dev/null @@ -1,26 +0,0 @@ -class IterationNodeError(ValueError): - """Base class for iteration node errors.""" - - -class IteratorVariableNotFoundError(IterationNodeError): - """Raised when the iterator variable is not found.""" - - -class InvalidIteratorValueError(IterationNodeError): - """Raised when the iterator value is invalid.""" - - -class StartNodeIdNotFoundError(IterationNodeError): - """Raised when the start node ID is not found.""" - - -class IterationGraphNotFoundError(IterationNodeError): - """Raised when the iteration graph is not found.""" - - -class IterationIndexNotFoundError(IterationNodeError): - """Raised when the iteration index is not found.""" - - -class ChildGraphAbortedError(IterationNodeError): - """Raised when a child graph aborts and the container must stop immediately.""" diff --git a/api/graphon/nodes/iteration/iteration_node.py b/api/graphon/nodes/iteration/iteration_node.py deleted file mode 100644 index c013739653..0000000000 --- a/api/graphon/nodes/iteration/iteration_node.py +++ /dev/null @@ -1,686 +0,0 @@ -import logging -from collections.abc import Generator, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from contextlib import suppress -from datetime import UTC, datetime -from threading import Lock -from typing import TYPE_CHECKING, Any, NewType, cast - -from typing_extensions import TypeIs - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from graphon.nodes.base import LLMUsageTrackingMixin -from graphon.nodes.base.node import Node -from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from 
graphon.runtime import VariablePool -from graphon.variables import IntegerVariable, NoneSegment -from graphon.variables.segments import ArrayAnySegment, ArraySegment - -from .exc import ( - ChildGraphAbortedError, - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) - -if TYPE_CHECKING: - from graphon.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) -_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" - -EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) - - -class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): - """ - Iteration Node. - """ - - node_type = BuiltinNodeTypes.ITERATION - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "iteration", - "config": { - "is_parallel": False, - "parallel_nums": 10, - "error_handle_mode": ErrorHandleMode.TERMINATED, - "flatten_output": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore - variable = self._get_iterator_variable() - - if self._is_empty_iteration(variable): - yield from self._handle_empty_iteration(variable) - return - - iterator_list_value = self._validate_and_get_iterator_list(variable) - inputs = {"iterator_selector": iterator_list_value} - - self._validate_start_node() - - started_at = datetime.now(UTC).replace(tzinfo=None) - iter_run_map: dict[str, float] = {} - outputs: list[object] = [] - usage_accumulator = [LLMUsage.empty_usage()] - - yield IterationStartedEvent( - start_at=started_at, - inputs=inputs, - metadata={"iteration_length": len(iterator_list_value)}, - ) - - try: - yield from self._execute_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_success( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - ) - except IterationNodeError as e: - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_failure( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - error=e, - ) - - def _get_iterator_variable(self) -> ArraySegment | NoneSegment: - variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - - if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") - - if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): - raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") - - return variable - - def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]: - return isinstance(variable, NoneSegment) or len(variable.value) == 0 - - def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]: - # Try our best to preserve the type information. 
- if isinstance(variable, ArraySegment): - output = variable.model_copy(update={"value": []}) - else: - output = ArrayAnySegment(value=[]) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - # TODO(QuantumGhost): is it possible to compute the type of `output` - # from graph definition? - outputs={"output": output}, - ) - ) - - def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]: - iterator_list_value = variable.to_object() - - if not isinstance(iterator_list_value, list): - raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - - return cast(list[object], iterator_list_value) - - def _validate_start_node(self) -> None: - if not self.node_data.start_node_id: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - def _execute_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - if self.node_data.is_parallel: - # Parallel mode execution - yield from self._execute_parallel_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - else: - # Sequential mode execution - for index, item in enumerate(iterator_list_value): - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - yield IterationNextEvent(index=index) - - graph_engine = self._create_graph_engine(index, item) - - # Run the iteration - try: - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - finally: - self._merge_graph_engine_usage(usage_accumulator=usage_accumulator, graph_engine=graph_engine) - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - def _execute_parallel_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - # Initialize outputs list with None values to maintain order - outputs.extend([None] * len(iterator_list_value)) - - # Determine the number of parallel workers - max_workers = min(self.node_data.parallel_nums, len(iterator_list_value)) - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit all iteration tasks - started_child_engines: dict[int, GraphEngine] = {} - started_child_engines_lock = Lock() - merged_usage_indexes: set[int] = set() - future_to_index: dict[ - Future[ - tuple[ - float, - list[GraphNodeEventBase], - object | None, - LLMUsage, - ] - ], - int, - ] = {} - for index, item in enumerate(iterator_list_value): - yield IterationNextEvent(index=index) - future = executor.submit( - self._execute_tracked_iteration_parallel, - index=index, - item=item, - started_child_engines=started_child_engines, - started_child_engines_lock=started_child_engines_lock, - ) - future_to_index[future] = index - - # Process completed iterations as they finish - for future in as_completed(future_to_index): - index = future_to_index[future] - try: - result = future.result() - ( - iteration_duration, - events, - output_value, - iteration_usage, - ) = result - - # Update outputs at the correct index - outputs[index] = output_value - - # Yield 
all events from this iteration - yield from events - - # The worker computes duration before we replay buffered events here, - # so slow downstream consumers don't inflate per-iteration timing. - iter_run_map[str(index)] = iteration_duration - - usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - merged_usage_indexes.add(index) - - except Exception as e: - if index not in merged_usage_indexes: - self._merge_graph_engine_usage( - usage_accumulator=usage_accumulator, - graph_engine=started_child_engines.get(index), - ) - merged_usage_indexes.add(index) - if isinstance(e, ChildGraphAbortedError): - self._abort_parallel_siblings( - future_to_index=future_to_index, - current_future=future, - started_child_engines=started_child_engines, - reason=str(e) or _DEFAULT_CHILD_ABORT_REASON, - ) - self._drain_parallel_siblings( - future_to_index=future_to_index, - current_future=future, - started_child_engines=started_child_engines, - usage_accumulator=usage_accumulator, - merged_usage_indexes=merged_usage_indexes, - ) - raise e - - # Handle errors based on error_handle_mode - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - # Cancel remaining futures and re-raise - for f in future_to_index: - if f != future: - f.cancel() - raise IterationNodeError(str(e)) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs[index] = None - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[index] = None # Will be filtered later - - # Remove None values if in REMOVE_ABNORMAL_OUTPUT mode - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[:] = [output for output in outputs if output is not None] - - @staticmethod - def _merge_graph_engine_usage( - *, - usage_accumulator: list[LLMUsage], - graph_engine: "GraphEngine | None", - ) -> None: - if graph_engine is None: - return - usage_accumulator[0] = IterationNode._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) - - def _abort_parallel_siblings( - self, - *, - future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], - current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], - started_child_engines: Mapping[int, "GraphEngine"], - reason: str, - ) -> None: - for future, index in future_to_index.items(): - if future == current_future: - continue - - graph_engine = started_child_engines.get(index) - if graph_engine is not None: - graph_engine.request_abort(reason) - - future.cancel() - - def _drain_parallel_siblings( - self, - *, - future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], - current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], - started_child_engines: Mapping[int, "GraphEngine"], - usage_accumulator: list[LLMUsage], - merged_usage_indexes: set[int], - ) -> None: - for future, index in future_to_index.items(): - if future == current_future: - continue - if future.cancelled(): - continue - - with suppress(Exception): - future.result() - - if index in merged_usage_indexes: - continue - - self._merge_graph_engine_usage( - usage_accumulator=usage_accumulator, - graph_engine=started_child_engines.get(index), - ) - merged_usage_indexes.add(index) - - def _execute_tracked_iteration_parallel( - self, - *, - index: int, - item: object, - started_child_engines: dict[int, "GraphEngine"], - started_child_engines_lock: Lock, - ) -> tuple[float, list[GraphNodeEventBase], object | None, 
LLMUsage]: - graph_engine = self._create_graph_engine(index, item) - with started_child_engines_lock: - started_child_engines[index] = graph_engine - - return self._execute_parallel_iteration_with_graph_engine( - index=index, - graph_engine=graph_engine, - ) - - def _execute_single_iteration_parallel( - self, - index: int, - item: object, - ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: - """Execute a single iteration in parallel mode and return results.""" - graph_engine = self._create_graph_engine(index, item) - return self._execute_parallel_iteration_with_graph_engine(index=index, graph_engine=graph_engine) - - def _execute_parallel_iteration_with_graph_engine( - self, - *, - index: int, - graph_engine: "GraphEngine", - ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: - """Execute a prepared child engine in parallel mode and return results.""" - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] - - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) - - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None - iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - return ( - iteration_duration, - events, - output_value, - graph_engine.graph_runtime_state.llm_usage, - ) - - def _handle_iteration_success( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationSucceededEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - ) - - # Yield final success event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": flattened_outputs}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]: - """ - Flatten the outputs list if all elements are lists. - This maintains backward compatibility with version 1.8.1 behavior. - - If flatten_output is False, returns outputs as-is (nested structure). - If flatten_output is True (default), flattens the list if all elements are lists. 
- """ - # If flatten_output is disabled, return outputs as-is - if not self.node_data.flatten_output: - return outputs - - if not outputs: - return outputs - - # Check if all non-None outputs are lists - non_none_outputs: list[object] = [output for output in outputs if output is not None] - if not non_none_outputs: - return outputs - - if all(isinstance(output, list) for output in non_none_outputs): - # Flatten the list of lists - flattened: list[Any] = [] - for output in outputs: - if isinstance(output, list): - flattened.extend(output) - elif output is not None: - # This shouldn't happen based on our check, but handle it gracefully - flattened.append(output) - return flattened - - return outputs - - def _handle_iteration_failure( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - error: IterationNodeError, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists (even in failure case) - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationFailedEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - error=str(error), - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(error), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, - } - iteration_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_config_data = node.get("data", {}) - if node_config_data.get("iteration_id") == node_id: - in_iteration_node_id = node.get("id") - if in_iteration_node_id: - iteration_node_ids.add(in_iteration_node_id) - - # Get node configs from graph_config instead of non-existent node_id_config_mapping - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("iteration_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, 
Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove iteration variables - sub_node_variable_mapping = { - sub_node_id + "." + key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - # remove variables that come from inside the iteration - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids} - - return variable_mapping - - def _append_iteration_info_to_event( - self, - event: GraphNodeEventBase, - iter_run_index: int, - ): - event.in_iteration_id = self._node_id - iter_metadata = { - WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **iter_metadata} - - def _run_single_iter( - self, - *, - variable_pool: VariablePool, - outputs: list[object], - graph_engine: "GraphEngine", - ) -> Generator[GraphNodeEventBase, None, None]: - rst = graph_engine.run() - # get current iteration index - index_variable = variable_pool.get([self._node_id, "index"]) - if not isinstance(index_variable, IntegerVariable): - raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found") - current_index = index_variable.value - for event in rst: - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START: - continue - - if isinstance(event, GraphNodeEventBase): - self._append_iteration_info_to_event(event=event, iter_run_index=current_index) - yield event - elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)): - result = variable_pool.get(self.node_data.output_selector) - if result is None: - outputs.append(None) - else: - outputs.append(result.to_object()) - return - elif isinstance(event, GraphRunAbortedEvent): - raise ChildGraphAbortedError(event.reason or _DEFAULT_CHILD_ABORT_REASON) - elif isinstance(event, GraphRunFailedEvent): - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - raise IterationNodeError(event.error) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs.append(None) - return - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - return - - def _create_graph_engine(self, index: int, item: object): - from graphon.entities import GraphInitParams - from graphon.runtime import ChildGraphNotFoundError - - # Create GraphInitParams for child graph execution. 
- graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - # Create a deep copy of the variable pool for each iteration - variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True) - - # append iteration variable (item, index) to variable pool - variable_pool_copy.add([self._node_id, "index"], index) - variable_pool_copy.add([self._node_id, "item"], item) - root_node_id = self.node_data.start_node_id - if root_node_id is None: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - try: - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - variable_pool=variable_pool_copy, - ) - except ChildGraphNotFoundError as exc: - raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/graphon/nodes/iteration/iteration_start_node.py b/api/graphon/nodes/iteration/iteration_start_node.py deleted file mode 100644 index 3a44d3d81d..0000000000 --- a/api/graphon/nodes/iteration/iteration_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.iteration.entities import IterationStartNodeData - - -class IterationStartNode(Node[IterationStartNodeData]): - """ - Iteration Start Node. - """ - - node_type = BuiltinNodeTypes.ITERATION_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/list_operator/__init__.py b/api/graphon/nodes/list_operator/__init__.py deleted file mode 100644 index 1877586ef4..0000000000 --- a/api/graphon/nodes/list_operator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import ListOperatorNode - -__all__ = ["ListOperatorNode"] diff --git a/api/graphon/nodes/list_operator/entities.py b/api/graphon/nodes/list_operator/entities.py deleted file mode 100644 index 0db1c75cdd..0000000000 --- a/api/graphon/nodes/list_operator/entities.py +++ /dev/null @@ -1,71 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class FilterOperator(StrEnum): - # string conditions - CONTAINS = "contains" - START_WITH = "start with" - END_WITH = "end with" - IS = "is" - IN = "in" - EMPTY = "empty" - NOT_CONTAINS = "not contains" - IS_NOT = "is not" - NOT_IN = "not in" - NOT_EMPTY = "not empty" - # number conditions - EQUAL = "=" - NOT_EQUAL = "≠" - LESS_THAN = "<" - GREATER_THAN = ">" - GREATER_THAN_OR_EQUAL = "≥" - LESS_THAN_OR_EQUAL = "≤" - - -class Order(StrEnum): - ASC = "asc" - DESC = "desc" - - -class FilterCondition(BaseModel): - key: str = "" - comparison_operator: FilterOperator = FilterOperator.CONTAINS - # the value is bool if the filter operator is comparing with - # a boolean constant. 
- value: str | Sequence[str] | bool = "" - - -class FilterBy(BaseModel): - enabled: bool = False - conditions: Sequence[FilterCondition] = Field(default_factory=list) - - -class OrderByConfig(BaseModel): - enabled: bool = False - key: str = "" - value: Order = Order.ASC - - -class Limit(BaseModel): - enabled: bool = False - size: int = -1 - - -class ExtractConfig(BaseModel): - enabled: bool = False - serial: str = "1" - - -class ListOperatorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.LIST_OPERATOR - variable: Sequence[str] = Field(default_factory=list) - filter_by: FilterBy - order_by: OrderByConfig - limit: Limit - extract_by: ExtractConfig = Field(default_factory=ExtractConfig) diff --git a/api/graphon/nodes/list_operator/exc.py b/api/graphon/nodes/list_operator/exc.py deleted file mode 100644 index f88aa0be29..0000000000 --- a/api/graphon/nodes/list_operator/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class ListOperatorError(ValueError): - """Base class for all ListOperator errors.""" - - pass - - -class InvalidFilterValueError(ListOperatorError): - pass - - -class InvalidKeyError(ListOperatorError): - pass - - -class InvalidConditionError(ListOperatorError): - pass diff --git a/api/graphon/nodes/list_operator/node.py b/api/graphon/nodes/list_operator/node.py deleted file mode 100644 index dad17a8f4a..0000000000 --- a/api/graphon/nodes/list_operator/node.py +++ /dev/null @@ -1,345 +0,0 @@ -from collections.abc import Callable, Sequence -from typing import Any, TypeAlias, TypeVar - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.file import File -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from graphon.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment - -from .entities import FilterOperator, ListOperatorNodeData, Order -from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError - -_SUPPORTED_TYPES_TUPLE = ( - ArrayFileSegment, - ArrayNumberSegment, - ArrayStringSegment, - ArrayBooleanSegment, -) -_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment - - -_T = TypeVar("_T") - - -def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: - """Returns the negation of a given filter function. If the original filter - returns `True` for a value, the negated filter will return `False`, and vice versa. 
- """ - - def wrapper(value: _T) -> bool: - return not filter_(value) - - return wrapper - - -class ListOperatorNode(Node[ListOperatorNodeData]): - node_type = BuiltinNodeTypes.LIST_OPERATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - inputs: dict[str, Sequence[object]] = {} - process_data: dict[str, Sequence[object]] = {} - outputs: dict[str, Any] = {} - - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) - if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - if not variable.value: - inputs = {"variable": []} - process_data = {"variable": []} - if isinstance(variable, ArraySegment): - result = variable.model_copy(update={"value": []}) - else: - result = ArrayAnySegment(value=[]) - outputs = {"result": result, "first_record": None, "last_record": None} - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - if not isinstance(variable, _SUPPORTED_TYPES_TUPLE): - error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - - if isinstance(variable, ArrayFileSegment): - inputs = {"variable": [item.to_dict() for item in variable.value]} - process_data["variable"] = [item.to_dict() for item in variable.value] - else: - inputs = {"variable": variable.value} - process_data["variable"] = variable.value - - try: - # Filter - if self.node_data.filter_by.enabled: - variable = self._apply_filter(variable) - - # Extract - if self.node_data.extract_by.enabled: - variable = self._extract_slice(variable) - - # Order - if self.node_data.order_by.enabled: - variable = self._apply_order(variable) - - # Slice - if self.node_data.limit.enabled: - variable = self._apply_slice(variable) - - outputs = { - "result": variable, - "first_record": variable.value[0] if variable.value else None, - "last_record": variable.value[-1] if variable.value else None, - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - except ListOperatorError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - - def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - filter_func: Callable[[Any], bool] - result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: - if isinstance(variable, ArrayStringSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = 
_get_number_filter_func(condition=condition.comparison_operator, value=float(value)) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayFileSegment): - if isinstance(condition.value, str): - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - elif isinstance(condition.value, bool): - raise ValueError(f"File filter expects a string value, got {type(condition.value)}") - else: - value = condition.value - filter_func = _get_file_filter_func( - key=condition.key, - condition=condition.comparison_operator, - value=value, - ) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - else: - if not isinstance(condition.value, bool): - raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}") - filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - return variable - - def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)): - result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC) - variable = variable.model_copy(update={"value": result}) - else: - result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value - ) - variable = variable.model_copy(update={"value": result}) - - return variable - - def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - result = variable.value[: self.node_data.limit.size] - return variable.model_copy(update={"value": result}) - - def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - if value < 1: - raise ValueError(f"Invalid serial index: must be >= 1, got {value}") - if value > len(variable.value): - raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}") - value -= 1 - result = variable.value[value] - return variable.model_copy(update={"value": [result]}) - - -def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: - match key: - case "size": - return lambda x: x.size - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: - match key: - case "name": - return lambda x: x.filename or "" - case "type": - return lambda x: str(x.type) - case "extension": - return lambda x: x.extension or "" - case "mime_type": - return lambda x: x.mime_type or "" - case "transfer_method": - return lambda x: str(x.transfer_method) - case "url": - return lambda x: x.remote_url or "" - case "related_id": - return lambda x: x.related_id or "" - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: - match condition: - case "contains": - return _contains(value) - case "start with": - return _startswith(value) - case "end with": - return _endswith(value) - case "is": - return _is(value) - case "in": - return _in(value) - case "empty": - return lambda x: x == "" - case "not contains": - return _negation(_contains(value)) - case 
"is not": - return _negation(_is(value)) - case "not in": - return _negation(_in(value)) - case "not empty": - return lambda x: x != "" - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: - match condition: - case "in": - return _in(value) - case "not in": - return _negation(_in(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: - match condition: - case "=": - return _eq(value) - case "≠": - return _ne(value) - case "<": - return _lt(value) - case "≤": - return _le(value) - case ">": - return _gt(value) - case "≥": - return _ge(value) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]: - match condition: - case FilterOperator.IS: - return _is(value) - case FilterOperator.IS_NOT: - return _negation(_is(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: - if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) - if key in {"type", "transfer_method"}: - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) - elif key == "size" and isinstance(value, str): - extract_number = _get_file_extract_number_func(key=key) - return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x)) - else: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _contains(value: str) -> Callable[[str], bool]: - return lambda x: value in x - - -def _startswith(value: str) -> Callable[[str], bool]: - return lambda x: x.startswith(value) - - -def _endswith(value: str) -> Callable[[str], bool]: - return lambda x: x.endswith(value) - - -def _is(value: _T) -> Callable[[_T], bool]: - return lambda x: x == value - - -def _in(value: str | Sequence[str]) -> Callable[[str], bool]: - return lambda x: x in value - - -def _eq(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x == value - - -def _ne(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x != value - - -def _lt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x < value - - -def _le(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x <= value - - -def _gt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x > value - - -def _ge(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x >= value - - -def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]): - extract_func: Callable[[File], Any] - if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}: - extract_func = _get_file_extract_string_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) - elif order_by == "size": - extract_func = _get_file_extract_number_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), 
reverse=order == Order.DESC) - else: - raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/graphon/nodes/llm/__init__.py b/api/graphon/nodes/llm/__init__.py deleted file mode 100644 index f7bc713f63..0000000000 --- a/api/graphon/nodes/llm/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from .node import LLMNode - -__all__ = [ - "LLMNode", - "LLMNodeChatModelMessage", - "LLMNodeCompletionModelPromptTemplate", - "LLMNodeData", - "ModelConfig", - "VisionConfig", -] diff --git a/api/graphon/nodes/llm/entities.py b/api/graphon/nodes/llm/entities.py deleted file mode 100644 index 196152548c..0000000000 --- a/api/graphon/nodes/llm/entities.py +++ /dev/null @@ -1,100 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from pydantic import BaseModel, Field, field_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode -from graphon.nodes.base.entities import VariableSelector -from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig - - -class ModelConfig(BaseModel): - provider: str - name: str - mode: LLMMode - completion_params: dict[str, Any] = Field(default_factory=dict) - - -class ContextConfig(BaseModel): - enabled: bool - variable_selector: list[str] | None = None - - -class VisionConfigOptions(BaseModel): - variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) - detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH - - -class VisionConfig(BaseModel): - enabled: bool = False - configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) - - @field_validator("configs", mode="before") - @classmethod - def convert_none_configs(cls, v: Any): - if v is None: - return VisionConfigOptions() - return v - - -class PromptConfig(BaseModel): - jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) - - @field_validator("jinja2_variables", mode="before") - @classmethod - def convert_none_jinja2_variables(cls, v: Any): - if v is None: - return [] - return v - - -class LLMNodeChatModelMessage(ChatModelMessage): - text: str = "" - jinja2_text: str | None = None - - -class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - jinja2_text: str | None = None - - -class LLMNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.LLM - model: ModelConfig - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: PromptConfig = Field(default_factory=PromptConfig) - memory: MemoryConfig | None = None - context: ContextConfig - vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: Mapping[str, Any] | None = None - # We used 'structured_output_enabled' in the past, but it's not a good name. - structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") - reasoning_format: Literal["separated", "tagged"] = Field( - # Keep tagged as default for backward compatibility - default="tagged", - description=( - """ - Strategy for handling model reasoning output. - - separated: Return clean text (without tags) + reasoning_content field. - Recommended for new workflows. 
Enables safe downstream parsing and - workflow variable access: {{#node_id.reasoning_content#}} - - tagged : Return original text (with <think> tags) + reasoning_content field. - Maintains full backward compatibility while still providing reasoning_content - for workflow automation. Frontend thinking panels work as before. - """ - ), - ) - - @field_validator("prompt_config", mode="before") - @classmethod - def convert_none_prompt_config(cls, v: Any): - if v is None: - return PromptConfig() - return v - - @property - def structured_output_enabled(self) -> bool: - return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/graphon/nodes/llm/exc.py b/api/graphon/nodes/llm/exc.py deleted file mode 100644 index 4d16095296..0000000000 --- a/api/graphon/nodes/llm/exc.py +++ /dev/null @@ -1,45 +0,0 @@ -class LLMNodeError(ValueError): - """Base class for LLM Node errors.""" - - -class VariableNotFoundError(LLMNodeError): - """Raised when a required variable is not found.""" - - -class InvalidContextStructureError(LLMNodeError): - """Raised when the context structure is invalid.""" - - -class InvalidVariableTypeError(LLMNodeError): - """Raised when the variable type is invalid.""" - - -class ModelNotExistError(LLMNodeError): - """Raised when the specified model does not exist.""" - - -class LLMModeRequiredError(LLMNodeError): - """Raised when LLM mode is required but not provided.""" - - -class NoPromptFoundError(LLMNodeError): - """Raised when no prompt is found in the LLM configuration.""" - - -class TemplateTypeNotSupportError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"Prompt type {type_name} is not supported.") - - -class MemoryRolePrefixRequiredError(LLMNodeError): - """Raised when memory role prefix is required for completion model.""" - - -class FileTypeNotSupportError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"{type_name} type is not supported by this model") - - -class UnsupportedPromptContentTypeError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"Prompt content type {type_name} is not supported.") diff --git a/api/graphon/nodes/llm/file_saver.py b/api/graphon/nodes/llm/file_saver.py deleted file mode 100644 index 0bedb42f3a..0000000000 --- a/api/graphon/nodes/llm/file_saver.py +++ /dev/null @@ -1,139 +0,0 @@ -import mimetypes -import typing as tp - -from graphon.file import File, FileTransferMethod, FileType -from graphon.file.constants import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol - - -class LLMFileSaver(tp.Protocol): - """LLMFileSaver is responsible for saving multimodal output returned by - LLM. - """ - - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - """save_binary_string saves the inline file data returned by LLM. - - Currently (2025-04-30), only some Google Gemini models will return - multimodal output as inline data. - - :param data: the contents of the file - :param mime_type: the media type of the file, specified by rfc6838 - (https://datatracker.ietf.org/doc/html/rfc6838) - :param file_type: The file type of the inline file. - :param extension_override: Override the auto-detected file extension while saving this file.
- - The default value is `None`, which means the file extension is not overridden and is guessed - from the `mime_type` attribute while saving the file. - - Setting it to a value other than `None` overrides the file's extension and - bypasses extension guessing when saving the file. - - In particular, setting it to the empty string (`""`) will leave the file extension empty. - - When it is not `None` or empty string (`""`), it should be a string beginning with a - dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` - and `tar.gz` are not. - """ - raise NotImplementedError() - - def save_remote_url(self, url: str, file_type: FileType) -> File: - """save_remote_url saves the file from a remote url returned by LLM. - - Currently (2025-04-30), no model returns multimodal output as a URL. - - :param url: the url of the file. - :param file_type: the file type of the file, check `FileType` enum for reference. - """ - raise NotImplementedError() - - -class FileSaverImpl(LLMFileSaver): - _tool_file_manager: ToolFileManagerProtocol - _file_reference_factory: FileReferenceFactoryProtocol - - def __init__( - self, - *, - tool_file_manager: ToolFileManagerProtocol, - file_reference_factory: FileReferenceFactoryProtocol, - http_client: HttpClientProtocol, - ): - self._tool_file_manager = tool_file_manager - self._file_reference_factory = file_reference_factory - self._http_client = http_client - - def save_remote_url(self, url: str, file_type: FileType) -> File: - http_response = self._http_client.get(url) - http_response.raise_for_status() - data = http_response.content - mime_type_from_header = http_response.headers.get("Content-Type") - mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header) - return self.save_binary_string(data, mime_type, file_type, extension_override=extension) - - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - tool_file = self._tool_file_manager.create_file_by_raw( - file_binary=data, - mimetype=mime_type, - ) - extension_override = _validate_extension_override(extension_override) - extension = _get_extension(mime_type, extension_override) - return self._file_reference_factory.build_from_mapping( - mapping={ - "type": file_type, - "transfer_method": FileTransferMethod.TOOL_FILE, - "filename": tool_file.name, - "extension": extension, - "mime_type": mime_type, - "size": len(data), - "tool_file_id": str(tool_file.id), - "related_id": str(tool_file.id), - "storage_key": tool_file.file_key, - } - ) - - -def _get_extension(mime_type: str, extension_override: str | None = None) -> str: - """_get_extension returns the extension of the file. - - If the `extension_override` parameter is set, this function should honor it and - return its value. - """ - if extension_override is not None: - return extension_override - return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION - - -def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]: - """_extract_content_type_and_extension tries to - guess the content type of a file from its url and the `Content-Type` header of the response.
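- 
-    Illustrative doctest (exact extensions depend on the platform's mimetypes database):
- 
-        >>> _extract_content_type_and_extension("https://example.com/a.png", None)
-        ('image/png', '.png')
-        >>> _extract_content_type_and_extension("https://example.com/data", "application/json")
-        ('application/json', '.json')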
- """ - if content_type_header: - extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION - return content_type_header, extension - content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE - extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION - return content_type, extension - - -def _validate_extension_override(extension_override: str | None) -> str | None: - # `extension_override` is allow to be `None or `""`. - if extension_override is None: - return None - if extension_override == "": - return "" - if not extension_override.startswith("."): - raise ValueError("extension_override should start with '.' if not None or empty.", extension_override) - return extension_override diff --git a/api/graphon/nodes/llm/llm_utils.py b/api/graphon/nodes/llm/llm_utils.py deleted file mode 100644 index 11a1d83a9d..0000000000 --- a/api/graphon/nodes/llm/llm_utils.py +++ /dev/null @@ -1,545 +0,0 @@ -from __future__ import annotations - -import json -import logging -import re -from collections.abc import Mapping, Sequence -from typing import Any - -from graphon.file import FileType, file_manager -from graphon.file.models import File -from graphon.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageContentUnionTypes, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.nodes.base.entities import VariableSelector -from graphon.runtime import VariablePool -from graphon.template_rendering import Jinja2TemplateRenderer -from graphon.variables import ArrayFileSegment, FileSegment -from graphon.variables.segments import ArrayAnySegment, NoneSegment - -from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig -from .exc import ( - InvalidVariableTypeError, - MemoryRolePrefixRequiredError, - NoPromptFoundError, - TemplateTypeNotSupportError, -) -from .runtime_protocols import PreparedLLMProtocol - -CONTEXT_PLACEHOLDER = "{{#context#}}" - -logger = logging.getLogger(__name__) - -VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") -MAX_RESOLVED_VALUE_LENGTH = 1024 - - -def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity: - model_schema = model_instance.get_model_schema() - if not model_schema: - raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}") - return model_schema - - -def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]: - variable = variable_pool.get(selector) - if variable is None: - return [] - elif isinstance(variable, FileSegment): - return [variable.value] - elif isinstance(variable, ArrayFileSegment): - return variable.value - elif isinstance(variable, NoneSegment | ArrayAnySegment): - return [] - raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") - - -def convert_history_messages_to_text( - *, - history_messages: Sequence[PromptMessage], - human_prefix: str, - ai_prefix: str, -) -> str: - string_messages: list[str] = [] - for message in history_messages: - if message.role == PromptMessageRole.USER: - role = human_prefix - elif message.role == PromptMessageRole.ASSISTANT: - role = ai_prefix - else: - 
continue - - if isinstance(message.content, list): - content_parts = [] - for content in message.content: - if isinstance(content, TextPromptMessageContent): - content_parts.append(content.data) - elif isinstance(content, ImagePromptMessageContent): - content_parts.append("[image]") - - inner_msg = "\n".join(content_parts) - string_messages.append(f"{role}: {inner_msg}") - else: - string_messages.append(f"{role}: {message.content}") - - return "\n".join(string_messages) - - -def fetch_memory_text( - *, - memory: PromptMessageMemory, - max_token_limit: int, - message_limit: int | None = None, - human_prefix: str = "Human", - ai_prefix: str = "Assistant", -) -> str: - history_messages = memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit, - ) - return convert_history_messages_to_text( - history_messages=history_messages, - human_prefix=human_prefix, - ai_prefix=ai_prefix, - ) - - -def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str = "", - memory: PromptMessageMemory | None = None, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - prompt_messages.extend( - handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - - prompt_messages.extend( - handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - ) - - if sys_query: - prompt_messages.extend( - handle_list_messages( - messages=[ - LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - prompt_messages.extend( - handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - ) - - memory_text = handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - prompt_content = prompt_messages[0].content - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - else: 
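- # No #histories# placeholder in this text part, so prepend the memory transcript instead.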
- content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - if sys_query: - if isinstance(prompt_content, str): - prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - _append_file_prompts( - prompt_messages=prompt_messages, - files=sys_files, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - _append_file_prompts( - prompt_messages=prompt_messages, - files=context_files or [], - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - - filtered_prompt_messages: list[PromptMessage] = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - if not model_schema.features: - if content_item.type == PromptMessageContentType.TEXT: - prompt_message_content.append(content_item) - continue - - if ( - ( - content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if not prompt_message_content: - continue - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - filtered_prompt_messages.append(prompt_message) - elif not prompt_message.is_empty(): - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding." 
- ) - - return filtered_prompt_messages, stop - - -def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - prompt_messages: list[PromptMessage] = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = render_jinja2_message( - template=message.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=message.role, - ) - ) - continue - - template = message.text.replace(CONTEXT_PLACEHOLDER, context) - segment_group = variable_pool.convert_template(template) - file_contents: list[PromptMessageContentUnionTypes] = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - elif isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - - if segment_group.text: - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=segment_group.text)], - role=message.role, - ) - ) - if file_contents: - prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role)) - - return prompt_messages - - -def render_jinja2_message( - *, - template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> str: - if not template: - return "" - if template_renderer is None: - raise ValueError("template_renderer is required for jinja2 prompt rendering") - - jinja2_inputs: dict[str, Any] = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_template(template, jinja2_inputs) - - -def handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - if template.edition_type == "jinja2": - result_text = render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - else: - template_text = template.text.replace(CONTEXT_PLACEHOLDER, context) - result_text = variable_pool.convert_template(template_text).text - return [ - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=PromptMessageRole.USER, - ) - ] - - -def combine_message_content_with_role( - *, - contents: str | list[PromptMessageContentUnionTypes] | None = None, - role: PromptMessageRole, -) -> PromptMessage: - match role: - 
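- # Wrap the given contents in the concrete prompt message class that matches the role.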
case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def calculate_rest_token( - *, - prompt_messages: list[PromptMessage], - model_instance: PreparedLLMProtocol, -) -> int: - rest_tokens = 2000 - runtime_model_schema = fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> Sequence[PromptMessage]: - if not memory or not memory_config: - return [] - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - return memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - - -def handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> str: - if not memory or not memory_config: - return "" - - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - - return fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - - -def _append_file_prompts( - *, - prompt_messages: list[PromptMessage], - files: Sequence[File], - vision_enabled: bool, - vision_detail: ImagePromptMessageContent.DETAIL, -) -> None: - if not vision_enabled or not files: - return - - file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files] - if ( - prompt_messages - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - existing_contents = prompt_messages[-1].content - assert isinstance(existing_contents, list) - prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - -def _coerce_resolved_value(raw: str) -> int | float | bool | str: - """Try to restore the original type from a resolved template string. - - Variable references are always resolved to text, but completion params may - expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to - the ``temperature`` parameter). 
This helper attempts a JSON parse so that - ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not - valid JSON literals are returned as-is. - """ - stripped = raw.strip() - if not stripped: - return raw - - try: - parsed: object = json.loads(stripped) - except (json.JSONDecodeError, ValueError): - return raw - - if isinstance(parsed, (int, float, bool)): - return parsed - return raw - - -def resolve_completion_params_variables( - completion_params: Mapping[str, Any], - variable_pool: VariablePool, -) -> dict[str, Any]: - """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params. - - Security notes: - - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to - prevent denial-of-service through excessively large variable payloads. - - This follows the same ``VariablePool.convert_template`` pattern used across - Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream - model plugin receives these values as structured JSON key-value pairs — they - are never concatenated into raw HTTP headers or SQL queries. - - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are - restored to their native type rather than sent as a bare string. - """ - resolved: dict[str, Any] = {} - for key, value in completion_params.items(): - if isinstance(value, str) and VARIABLE_PATTERN.search(value): - segment_group = variable_pool.convert_template(value) - text = segment_group.text - if len(text) > MAX_RESOLVED_VALUE_LENGTH: - logger.warning( - "Resolved value for param '%s' truncated from %d to %d chars", - key, - len(text), - MAX_RESOLVED_VALUE_LENGTH, - ) - text = text[:MAX_RESOLVED_VALUE_LENGTH] - resolved[key] = _coerce_resolved_value(text) - else: - resolved[key] = value - return resolved diff --git a/api/graphon/nodes/llm/node.py b/api/graphon/nodes/llm/node.py deleted file mode 100644 index 4de2a95465..0000000000 --- a/api/graphon/nodes/llm/node.py +++ /dev/null @@ -1,1372 +0,0 @@ -from __future__ import annotations - -import base64 -import io -import json -import logging -import re -import time -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, cast - -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File, FileType, file_manager -from graphon.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, - LLMStructuredOutput, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageContentUnionTypes, - PromptMessageRole, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.node_events import ( - ModelInvokeCompletedEvent, - NodeEventBase, - NodeRunResult, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, -) -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node 
-from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.nodes.llm.runtime_protocols import ( - PreparedLLMProtocol, - PromptMessageSerializerProtocol, - RetrieverAttachmentLoaderProtocol, -) -from graphon.nodes.protocols import HttpClientProtocol -from graphon.prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from graphon.runtime import VariablePool -from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from graphon.variables import ( - ArrayFileSegment, - ArraySegment, - FileSegment, - NoneSegment, - ObjectSegment, - StringSegment, -) - -from . import llm_utils -from .entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, -) -from .exc import ( - InvalidContextStructureError, - InvalidVariableTypeError, - LLMNodeError, - MemoryRolePrefixRequiredError, - NoPromptFoundError, - TemplateTypeNotSupportError, - VariableNotFoundError, -) -from .file_saver import LLMFileSaver - -if TYPE_CHECKING: - from graphon.file.models import File - from graphon.runtime import GraphRuntimeState - -logger = logging.getLogger(__name__) - - -class LLMNode(Node[LLMNodeData]): - node_type = BuiltinNodeTypes.LLM - - # Compiled regex for extracting <think> blocks (with compatibility for attributes) - _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL) - - # Instance attributes specific to LLMNode. - # Output variable for file - _file_outputs: list[File] - - _llm_file_saver: LLMFileSaver - _retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None - _prompt_message_serializer: PromptMessageSerializerProtocol - _jinja2_template_renderer: Jinja2TemplateRenderer | None - _model_instance: PreparedLLMProtocol - _memory: PromptMessageMemory | None - _default_query_selector: tuple[str, ...] | None - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - *, - credentials_provider: object | None = None, - model_factory: object | None = None, - model_instance: PreparedLLMProtocol, - http_client: HttpClientProtocol, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver, - prompt_message_serializer: PromptMessageSerializerProtocol, - retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, - default_query_selector: Sequence[str] | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for multimodal outputs.
- self._file_outputs = [] - - _ = credentials_provider, model_factory, http_client - self._model_instance = model_instance - self._memory = memory - - self._llm_file_saver = llm_file_saver - self._prompt_message_serializer = prompt_message_serializer - self._retriever_attachment_loader = retriever_attachment_loader - self._jinja2_template_renderer = jinja2_template_renderer - self._default_query_selector = tuple(default_query_selector) if default_query_selector is not None else None - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - node_inputs: dict[str, Any] = {} - process_data: dict[str, Any] = {} - result_text = "" - clean_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - reasoning_content = None - variable_pool = self.graph_runtime_state.variable_pool - - try: - # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) - - # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) - - # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) - - # merge inputs - inputs.update(jinja_inputs) - - # fetch files - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, - ) - if self.node_data.vision.enabled - else [] - ) - - if files: - node_inputs["#files#"] = [file.to_dict() for file in files] - - # fetch context value - generator = self._fetch_context(node_data=self.node_data) - context = None - context_files: list[File] = [] - if generator is not None: - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event - if context: - node_inputs["#context#"] = context - - if context_files: - node_inputs["#context_files#"] = [file.model_dump() for file in context_files] - - # fetch model config - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - model_name = model_instance.model_name - model_provider = model_instance.provider - model_stop = model_instance.stop - - memory = self._memory - - query: str | None = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template - if ( - not query - and self._default_query_selector - and (query_variable := variable_pool.get(self._default_query_selector)) - ): - query = query_variable.text - - prompt_messages, stop = LLMNode.fetch_prompt_messages( - sys_query=query, - sys_files=files, - context=context or "", - memory=memory, - model_instance=model_instance, - stop=model_stop, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, - context_files=context_files, - jinja2_template_renderer=self._jinja2_template_renderer, - ) - - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=self.node_data.structured_output, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, 
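- # reasoning_format decides whether <think> blocks stay inline ("tagged") or are split out into reasoning_content ("separated").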
- reasoning_format=self.node_data.reasoning_format, - ) - - structured_output: LLMStructuredOutput | None = None - - for event in generator: - if isinstance(event, StreamChunkEvent): - yield event - elif isinstance(event, ModelInvokeCompletedEvent): - # Raw text - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - reasoning_content = event.reasoning_content or "" - - # For downstream nodes, determine clean text based on reasoning_format - if self.node_data.reasoning_format == "tagged": - # Keep tags for backward compatibility - clean_text = result_text - else: - # Extract clean text from tags - clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format) - - # Process structured output if available from the event. - structured_output = ( - LLMStructuredOutput(structured_output=event.structured_output) - if event.structured_output - else None - ) - - break - elif isinstance(event, LLMStructuredOutput): - structured_output = event - - process_data = { - "model_mode": self.node_data.model.mode, - "prompts": self._prompt_message_serializer.serialize( - model_mode=self.node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_provider, - "model_name": model_name, - } - - outputs = { - "text": clean_text, - "reasoning_content": reasoning_content, - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - } - if structured_output: - outputs["structured_output"] = structured_output.structured_output - if self._file_outputs: - outputs["files"] = ArrayFileSegment(value=self._file_outputs) - - # Send final chunk event to indicate streaming is complete - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - except ValueError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data, - error_type=type(e).__name__, - llm_usage=usage, - ) - ) - except Exception as e: - logger.exception("error while executing llm node") - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data, - error_type=type(e).__name__, - llm_usage=usage, - ) - ) - - @staticmethod - def invoke_llm( - *, - model_instance: PreparedLLMProtocol, - prompt_messages: Sequence[PromptMessage], - stop: Sequence[str] | None = None, - structured_output_enabled: bool, - structured_output: Mapping[str, Any] | None = None, - file_saver: LLMFileSaver, - file_outputs: list[File], - node_id: str, - node_type: NodeType, - reasoning_format: Literal["separated", "tagged"] = "tagged", - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - model_parameters = model_instance.parameters - invoke_model_parameters = dict(model_parameters) - invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] - if structured_output_enabled: - output_schema = 
LLMNode.fetch_structured_output_schema( - structured_output=structured_output or {}, - ) - request_start_time = time.perf_counter() - - invoke_result = cast( - LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], - model_instance.invoke_llm_with_structured_output( - prompt_messages=prompt_messages, - json_schema=output_schema, - model_parameters=invoke_model_parameters, - stop=stop, - stream=True, - ), - ) - else: - request_start_time = time.perf_counter() - - invoke_result = cast( - LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], - model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=invoke_model_parameters, - tools=None, - stop=stop, - stream=True, - ), - ) - - return LLMNode.handle_invoke_result( - invoke_result=invoke_result, - file_saver=file_saver, - file_outputs=file_outputs, - node_id=node_id, - node_type=node_type, - model_instance=model_instance, - reasoning_format=reasoning_format, - request_start_time=request_start_time, - ) - - @staticmethod - def handle_invoke_result( - *, - invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], - file_saver: LLMFileSaver, - file_outputs: list[File], - node_id: str, - node_type: NodeType, - model_instance: PreparedLLMProtocol | object, - reasoning_format: Literal["separated", "tagged"] = "tagged", - request_start_time: float | None = None, - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - # For blocking mode - if isinstance(invoke_result, LLMResult): - duration = None - if request_start_time is not None: - duration = time.perf_counter() - request_start_time - invoke_result.usage.latency = round(duration, 3) - event = LLMNode.handle_blocking_result( - invoke_result=invoke_result, - saver=file_saver, - file_outputs=file_outputs, - reasoning_format=reasoning_format, - request_latency=duration, - ) - yield event - return - - # For streaming mode - model = "" - prompt_messages: list[PromptMessage] = [] - - usage = LLMUsage.empty_usage() - finish_reason = None - full_text_buffer = io.StringIO() - - # Initialize streaming metrics tracking - start_time = request_start_time if request_start_time is not None else time.perf_counter() - first_token_time = None - has_content = False - - collected_structured_output = None # Collect structured_output from streaming chunks - # Consume the invoke result and handle generator exception - try: - for result in invoke_result: - if isinstance(result, LLMResultChunkWithStructuredOutput): - # Collect structured_output from the chunk - if result.structured_output is not None: - collected_structured_output = dict(result.structured_output) - yield result - if isinstance(result, LLMResultChunk): - contents = result.delta.message.content - for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( - contents=contents, - file_saver=file_saver, - file_outputs=file_outputs, - ): - # Detect first token for TTFT calculation - if text_part and not has_content: - first_token_time = time.perf_counter() - has_content = True - - full_text_buffer.write(text_part) - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=text_part, - is_final=False, - ) - - # Update the whole metadata - if not model and result.model: - model = result.model - if len(prompt_messages) == 0: - # TODO(QuantumGhost): it seems that this update has no visible effect. - # What's the purpose of the line below?
prompt_messages = list(result.prompt_messages) - if usage.prompt_tokens == 0 and result.delta.usage: - usage = result.delta.usage - if finish_reason is None and result.delta.finish_reason: - finish_reason = result.delta.finish_reason - except Exception as e: - if hasattr(model_instance, "is_structured_output_parse_error") and cast( - PreparedLLMProtocol, model_instance - ).is_structured_output_parse_error(e): - raise LLMNodeError(f"Failed to parse structured output: {e}") from e - if type(e).__name__ == "OutputParserError": - raise LLMNodeError(f"Failed to parse structured output: {e}") from e - raise - - # Extract reasoning content from <think> tags in the main text - full_text = full_text_buffer.getvalue() - - if reasoning_format == "tagged": - # Keep <think> tags in text for backward compatibility - clean_text = full_text - reasoning_content = "" - else: - # Extract clean text and reasoning from <think> tags - clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) - - # Calculate streaming metrics - end_time = time.perf_counter() - total_duration = end_time - start_time - usage.latency = round(total_duration, 3) - if has_content and first_token_time: - gen_ai_server_time_to_first_token = first_token_time - start_time - llm_streaming_time_to_generate = end_time - first_token_time - usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3) - usage.time_to_generate = round(llm_streaming_time_to_generate, 3) - - yield ModelInvokeCompletedEvent( - # Use clean_text for separated mode, full_text for tagged mode - text=clean_text if reasoning_format == "separated" else full_text, - usage=usage, - finish_reason=finish_reason, - # Reasoning content for workflow variables and downstream nodes - reasoning_content=reasoning_content, - # Pass structured output if collected from streaming chunks - structured_output=collected_structured_output, - ) - - @staticmethod - def _image_file_to_markdown(file: File, /): - text_chunk = f"![]({file.generate_url()})" - return text_chunk - - @classmethod - def _split_reasoning( - cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged" - ) -> tuple[str, str]: - """ - Split reasoning content from text based on reasoning_format strategy. - - Args: - text: Full text that may contain <think>...</think> blocks - reasoning_format: Strategy for handling reasoning content - - "separated": Remove <think> tags and return clean text + reasoning_content field - - "tagged": Keep <think> tags in text, return empty reasoning_content - - Returns: - tuple of (clean_text, reasoning_content) - """ - - if reasoning_format == "tagged": - return text, "" - - # Find all <think>...</think> blocks (case-insensitive) - matches = cls._THINK_PATTERN.findall(text) - - # Extract reasoning content from all <think> blocks - reasoning_content = "\n".join(match.strip() for match in matches) if matches else "" - - # Remove all <think>...</think>
blocks from original text - clean_text = cls._THINK_PATTERN.sub("", text) - - # Clean up extra whitespace - clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() - - # Separated mode: always return clean text and reasoning_content - return clean_text, reasoning_content or "" - - def _transform_chat_messages( - self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / - ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == "jinja2" and messages.jinja2_text: - messages.text = messages.jinja2_text - - return messages - - for message in messages: - if message.edition_type == "jinja2" and message.jinja2_text: - message.text = message.jinja2_text - - return messages - - def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: - variables: dict[str, Any] = {} - - if not node_data.prompt_config: - return variables - - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - - def parse_dict(input_dict: Mapping[str, Any]) -> str: - """ - Parse dict into string - """ - # check if it's a context structure - if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: - return str(input_dict["content"]) - - # else, parse the dict - try: - return json.dumps(input_dict, ensure_ascii=False) - except Exception: - return str(input_dict) - - if isinstance(variable, ArraySegment): - result = "" - for item in variable.value: - if isinstance(item, dict): - result += parse_dict(item) - else: - result += str(item) - result += "\n" - value = result.strip() - elif isinstance(variable, ObjectSegment): - value = parse_dict(variable.value) - else: - value = variable.text - - variables[variable_name] = value - - return variables - - def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]: - inputs = {} - prompt_template = node_data.prompt_template - - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, CompletionModelPromptTemplate): - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - - for variable_selector in variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - if isinstance(variable, NoneSegment): - inputs[variable_selector.variable] = "" - inputs[variable_selector.variable] = variable.to_object() - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - if 
isinstance(variable, NoneSegment): - continue - inputs[variable_selector.variable] = variable.to_object() - - return inputs - - def _fetch_context(self, node_data: LLMNodeData): - if not node_data.context.enabled: - return - - if not node_data.context.variable_selector: - return - - context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) - if context_value_variable: - if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent( - retriever_resources=[], context=context_value_variable.value, context_files=[] - ) - elif isinstance(context_value_variable, ArraySegment): - context_str = "" - original_retriever_resource: list[dict[str, Any]] = [] - context_files: list[File] = [] - for item in context_value_variable.value: - if isinstance(item, str): - context_str += item + "\n" - else: - if "content" not in item: - raise InvalidContextStructureError(f"Invalid context structure: {item}") - - if item.get("summary"): - context_str += item["summary"] + "\n" - context_str += item["content"] + "\n" - - retriever_resource = self._convert_to_original_retriever_resource(item) - if retriever_resource: - original_retriever_resource.append(retriever_resource) - segment_id = retriever_resource.get("segment_id") - if not segment_id: - continue - if self._retriever_attachment_loader is not None: - context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id)) - yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, - context=context_str.strip(), - context_files=context_files, - ) - - def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None: - if ( - "metadata" in context_dict - and "_source" in context_dict["metadata"] - and context_dict["metadata"]["_source"] == "knowledge" - ): - metadata = context_dict.get("metadata", {}) - - return { - "position": metadata.get("position"), - "dataset_id": metadata.get("dataset_id"), - "dataset_name": metadata.get("dataset_name"), - "document_id": metadata.get("document_id"), - "document_name": metadata.get("document_name"), - "data_source_type": metadata.get("data_source_type"), - "segment_id": metadata.get("segment_id"), - "retriever_from": metadata.get("retriever_from"), - "score": metadata.get("score"), - "hit_count": metadata.get("segment_hit_count"), - "word_count": metadata.get("segment_word_count"), - "segment_position": metadata.get("segment_position"), - "index_node_hash": metadata.get("segment_index_node_hash"), - "content": context_dict.get("content"), - "page": metadata.get("page"), - "doc_metadata": metadata.get("doc_metadata"), - "files": context_dict.get("files"), - "summary": context_dict.get("summary"), - } - - return None - - @staticmethod - def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str = "", - memory: PromptMessageMemory | None = None, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, - ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = 
llm_utils.fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - # For chat model - prompt_messages.extend( - LLMNode.handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - # Get memory messages for chat mode - memory_messages = _handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - # Extend prompt_messages with memory messages - prompt_messages.extend(memory_messages) - - # Add current query to the prompt messages - if sys_query: - message = LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - prompt_messages.extend( - LLMNode.handle_list_messages( - messages=[message], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - # For completion model - prompt_messages.extend( - _handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - jinja2_template_renderer=jinja2_template_renderer, - ) - ) - - # Get memory text for completion model - memory_text = _handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - # Insert histories into the prompt - prompt_content = prompt_messages[0].content - # For issue #11247 - Check if prompt content is a string or a list - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - else: - content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - # Add current query to the prompt message - if sys_query: - if isinstance(prompt_content, str): - prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - # The sys_files will be deprecated later - if vision_enabled and sys_files: - file_prompts = [] - for file in sys_files: - file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) - file_prompts.append(file_prompt) - # If last prompt is a user prompt, add files into its contents, - # otherwise append a new user prompt - if ( - len(prompt_messages) > 0 - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - prompt_messages[-1] = UserPromptMessage(content=file_prompts + 
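- # (file prompts are prepended so they come before the user's existing text contents)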
prompt_messages[-1].content) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - # The context_files - if vision_enabled and context_files: - file_prompts = [] - for file in context_files: - file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) - file_prompts.append(file_prompt) - # If last prompt is a user prompt, add files into its contents, - # otherwise append a new user prompt - if ( - len(prompt_messages) > 0 - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - # Remove empty messages and filter unsupported content - filtered_prompt_messages = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - # Skip content if features are not defined - if not model_schema.features: - if content_item.type != PromptMessageContentType.TEXT: - continue - prompt_message_content.append(content_item) - continue - - # Skip content if corresponding feature is not supported - if ( - ( - content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - if prompt_message.is_empty(): - continue - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. " - "Please ensure a prompt is properly configured before proceeding." 
- ) - - return filtered_prompt_messages, stop - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LLMNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - _ = graph_config # Explicitly mark as unused - prompt_template = node_data.prompt_template - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - if prompt.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - else: - raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - - variable_mapping: dict[str, Any] = {} - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - if node_data.context.enabled: - variable_mapping["#context#"] = node_data.context.variable_selector - - if node_data.vision.enabled: - variable_mapping["#files#"] = node_data.vision.configs.variable_selector - - if node_data.prompt_config: - enable_jinja = False - - if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type == "jinja2": - enable_jinja = True - else: - for prompt in prompt_template: - if prompt.edition_type == "jinja2": - enable_jinja = True - break - - if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - variable_mapping = {node_id + "." 
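- # Prefix every key with the node id so variable selectors stay unique across the graph.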
+ key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "llm", - "config": { - "prompt_templates": { - "chat_model": { - "prompts": [ - {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} - ] - }, - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "prompt": { - "text": "Here are the chat histories between human and assistant, inside " - "<histories></histories> XML tags.\n\n<histories>\n{{" - "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic", - }, - "stop": ["Human:"], - }, - } - }, - } - - @staticmethod - def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, - ) -> Sequence[PromptMessage]: - prompt_messages: list[PromptMessage] = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=message.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - jinja2_template_renderer=jinja2_template_renderer, - ) - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=message.role - ) - prompt_messages.append(prompt_message) - else: - # Get segment group from basic message - template = message.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) - segment_group = variable_pool.convert_template(template) - - # Process segments for images - file_contents = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - elif isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - - # Create message with text from all segments - plain_text = segment_group.text - if plain_text: - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=plain_text)], role=message.role - ) - prompt_messages.append(prompt_message) - - if file_contents: - # Create message with image contents - prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) - prompt_messages.append(prompt_message) - - return prompt_messages - - @staticmethod - def handle_blocking_result( - *, - invoke_result: LLMResult | LLMResultWithStructuredOutput, - saver: LLMFileSaver, - file_outputs: list[File], - reasoning_format: Literal["separated", "tagged"] = "tagged", - request_latency: float | None = None, - ) -> ModelInvokeCompletedEvent: - buffer = io.StringIO() - for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( - contents=invoke_result.message.content, - file_saver=saver, - file_outputs=file_outputs, - ): - buffer.write(text_part) - - # Extract reasoning content from <think> tags
in the main text - full_text = buffer.getvalue() - - if reasoning_format == "tagged": - # Keep <think> tags in text for backward compatibility - clean_text = full_text - reasoning_content = "" - else: - # Extract clean text and reasoning from <think> tags - clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) - - event = ModelInvokeCompletedEvent( - # Use clean_text for separated mode, full_text for tagged mode - text=clean_text if reasoning_format == "separated" else full_text, - usage=invoke_result.usage, - finish_reason=None, - # Reasoning content for workflow variables and downstream nodes - reasoning_content=reasoning_content, - # Pass structured output if enabled - structured_output=getattr(invoke_result, "structured_output", None), - ) - if request_latency is not None: - event.usage.latency = round(request_latency, 3) - return event - - @staticmethod - def save_multimodal_image_output( - *, - content: ImagePromptMessageContent, - file_saver: LLMFileSaver, - ) -> File: - """_save_multimodal_output saves multi-modal contents generated by LLM plugins. - - There are two kinds of multimodal outputs: - - - Inlined data encoded in base64, which would be saved to storage directly. - - Remote files referenced by an url, which would be downloaded and then saved to storage. - - Currently, only image files are supported. - """ - if content.url != "": - saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) - else: - saved_file = file_saver.save_binary_string( - data=base64.b64decode(content.base64_data), - mime_type=content.mime_type, - file_type=FileType.IMAGE, - ) - return saved_file - - @staticmethod - def fetch_structured_output_schema( - *, - structured_output: Mapping[str, Any], - ) -> dict[str, Any]: - """ - Fetch the structured output schema from the node data. - - Returns: - dict[str, Any]: The structured output schema - """ - if not structured_output: - raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) - if not structured_output_schema: - raise LLMNodeError("Please provide a valid structured output schema") - - try: - schema = json.loads(structured_output_schema) - if not isinstance(schema, dict): - raise LLMNodeError("structured_output_schema must be a JSON object") - return schema - except json.JSONDecodeError: - raise LLMNodeError("structured_output_schema is not valid JSON format") - - @staticmethod - def _save_multimodal_output_and_convert_result_to_markdown( - *, - contents: str | list[PromptMessageContentUnionTypes] | None, - file_saver: LLMFileSaver, - file_outputs: list[File], - ) -> Generator[str, None, None]: - """Convert intermediate prompt messages into strings and yield them to the caller. - - If the messages contain non-textual content (e.g., multimedia like images or videos), - it will be saved separately, and the corresponding Markdown representation will - be yielded to the caller. - """ - - # NOTE(QuantumGhost): This function should yield results to the caller immediately - # whenever new content or partial content is available. Avoid any intermediate buffering - # of results. Additionally, do not yield empty strings; instead, yield from an empty list - # if necessary.
- if contents is None: - yield from [] - return - if isinstance(contents, str): - yield contents - else: - for item in contents: - if isinstance(item, TextPromptMessageContent): - yield item.data - elif isinstance(item, ImagePromptMessageContent): - file = LLMNode.save_multimodal_image_output( - content=item, - file_saver=file_saver, - ) - file_outputs.append(file) - yield LLMNode._image_file_to_markdown(file) - else: - logger.warning("unknown item type encountered, type=%s", type(item)) - yield str(item) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - -def _combine_message_content_with_role( - *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole -): - match role: - case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def _render_jinja2_message( - *, - template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - jinja2_template_renderer: Jinja2TemplateRenderer | None, -): - if not template: - return "" - - jinja2_inputs = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - if jinja2_template_renderer is None: - raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.") - return jinja2_template_renderer.render_template(template, jinja2_inputs) - - -def _calculate_rest_token( - *, - prompt_messages: list[PromptMessage], - model_instance: PreparedLLMProtocol, -) -> int: - rest_tokens = 2000 - runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def _handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: PreparedLLMProtocol, -) -> Sequence[PromptMessage]: - memory_messages: Sequence[PromptMessage] = [] - # Get messages from memory for chat model - if memory and memory_config: - rest_tokens = _calculate_rest_token( - prompt_messages=[], - model_instance=model_instance, - ) - memory_messages = memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - return memory_messages - - -def _handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - 
model_instance: PreparedLLMProtocol, -) -> str: - memory_text = "" - # Get history text from memory for completion model - if memory and memory_config: - rest_tokens = _calculate_rest_token( - prompt_messages=[], - model_instance=model_instance, - ) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - memory_text = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - return memory_text - - -def _handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - jinja2_template_renderer: Jinja2TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - """Handle completion template processing outside of LLMNode class. - - Args: - template: The completion model prompt template - context: Context string - jinja2_variables: Variables for jinja2 template rendering - variable_pool: Variable pool for template conversion - - Returns: - Sequence of prompt messages - """ - prompt_messages = [] - if template.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - jinja2_template_renderer=jinja2_template_renderer, - ) - else: - template_text = template.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) - result_text = variable_pool.convert_template(template_text).text - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER - ) - prompt_messages.append(prompt_message) - return prompt_messages diff --git a/api/graphon/nodes/llm/protocols.py b/api/graphon/nodes/llm/protocols.py deleted file mode 100644 index 65bfd533d1..0000000000 --- a/api/graphon/nodes/llm/protocols.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from typing import Any, Protocol - -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol - - -class CredentialsProvider(Protocol): - """Port for loading runtime credentials for a provider/model pair.""" - - def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: - """Return credentials for the target provider/model or raise a domain error.""" - ... - - -class ModelFactory(Protocol): - """Port for creating prepared graph-facing LLM runtimes for execution.""" - - def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol: - """Create a prepared LLM runtime that is ready for graph execution.""" - ... 
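The two ports deleted above are structural typing.Protocol contracts: any object whose methods match the declared signatures satisfies them, with no inheritance relationship required, which is what let the node accept test doubles in place of real credential and model infrastructure. A minimal sketch of that mechanic under stated assumptions — StaticCredentialsProvider and resolve_credentials are hypothetical names invented for illustration, not part of the deleted module:

from typing import Any, Protocol


class CredentialsProvider(Protocol):
    # Same shape as the deleted port: provider/model pair in, credentials out.
    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...


class StaticCredentialsProvider:
    # Hypothetical stand-in backed by an in-memory dict; it satisfies the
    # port structurally, so no subclass relationship is declared anywhere.
    def __init__(self, table: dict[tuple[str, str], dict[str, Any]]):
        self._table = table

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        try:
            return self._table[(provider_name, model_name)]
        except KeyError as exc:
            raise LookupError(f"no credentials for {provider_name}/{model_name}") from exc


def resolve_credentials(provider: CredentialsProvider, provider_name: str, model_name: str) -> dict[str, Any]:
    # Callers are typed against the port, so tests can pass the stub above.
    return provider.fetch(provider_name, model_name)


creds = resolve_credentials(
    StaticCredentialsProvider({("openai", "gpt-4o"): {"api_key": "sk-test"}}),
    "openai",
    "gpt-4o",
)
assert creds["api_key"] == "sk-test"

The same substitution works for ModelFactory: a prepared runtime can be swapped without touching node code.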
diff --git a/api/graphon/nodes/llm/runtime_protocols.py b/api/graphon/nodes/llm/runtime_protocols.py deleted file mode 100644 index dbe415d363..0000000000 --- a/api/graphon/nodes/llm/runtime_protocols.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Mapping, Sequence -from typing import Any, Protocol - -from graphon.file import File -from graphon.model_runtime.entities import LLMMode, PromptMessage -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity - - -class PreparedLLMProtocol(Protocol): - """A graph-facing LLM runtime with provider-specific setup already applied.""" - - @property - def provider(self) -> str: ... - - @property - def model_name(self) -> str: ... - - @property - def parameters(self) -> Mapping[str, Any]: ... - - @parameters.setter - def parameters(self, value: Mapping[str, Any]) -> None: ... - - @property - def stop(self) -> Sequence[str] | None: ... - - def get_model_schema(self) -> AIModelEntity: ... - - def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... - - def invoke_llm( - self, - *, - prompt_messages: Sequence[PromptMessage], - model_parameters: Mapping[str, Any], - tools: Sequence[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, - ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... - - def invoke_llm_with_structured_output( - self, - *, - prompt_messages: Sequence[PromptMessage], - json_schema: Mapping[str, Any], - model_parameters: Mapping[str, Any], - stop: Sequence[str] | None, - stream: bool, - ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - def is_structured_output_parse_error(self, error: Exception) -> bool: ... - - -class PromptMessageSerializerProtocol(Protocol): - """Port for converting compiled prompt messages into persisted process data.""" - - def serialize( - self, - *, - model_mode: LLMMode, - prompt_messages: Sequence[PromptMessage], - ) -> Any: ... - - -class RetrieverAttachmentLoaderProtocol(Protocol): - """Port for resolving retriever segment attachments into graph file references.""" - - def load(self, *, segment_id: str) -> Sequence[File]: ... 
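PreparedLLMProtocol above is the surface that _calculate_rest_token (deleted earlier in this patch) is written against: it reads CONTEXT_SIZE from the model schema, counts the compiled prompt with get_llm_num_tokens, then budgets the memory window as context size minus the reserved completion max_tokens minus the prompt tokens, floored at zero. The arithmetic in isolation, with illustrative numbers — the helper name below is hypothetical; only the formula mirrors the deleted code:

def remaining_history_budget(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    # Tokens left in the context window after reserving room for the
    # completion (max_tokens) and the tokens already spent on the prompt.
    return max(context_size - max_tokens - prompt_tokens, 0)


# An 8192-token model reserving 1024 completion tokens for a 3000-token
# prompt leaves 4168 tokens of history to pull from memory:
assert remaining_history_budget(8192, 1024, 3000) == 4168
# Like the deleted helper, the budget clamps at zero instead of going negative:
assert remaining_history_budget(4096, 1024, 3500) == 0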
diff --git a/api/graphon/nodes/loop/__init__.py b/api/graphon/nodes/loop/__init__.py deleted file mode 100644 index 9fe695607b..0000000000 --- a/api/graphon/nodes/loop/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .entities import LoopNodeData -from .loop_end_node import LoopEndNode -from .loop_node import LoopNode -from .loop_start_node import LoopStartNode - -__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"] diff --git a/api/graphon/nodes/loop/entities.py b/api/graphon/nodes/loop/entities.py deleted file mode 100644 index e7362769e9..0000000000 --- a/api/graphon/nodes/loop/entities.py +++ /dev/null @@ -1,107 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Any, Literal - -from pydantic import AfterValidator, BaseModel, Field, field_validator - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base import BaseLoopNodeData, BaseLoopState -from graphon.utils.condition.entities import Condition -from graphon.variables.types import SegmentType - -_VALID_VAR_TYPE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: - if seg_type not in _VALID_VAR_TYPE: - raise ValueError(...) - return seg_type - - -class LoopVariableData(BaseModel): - """ - Loop Variable Data. - """ - - label: str - var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] - value_type: Literal["variable", "constant"] - value: Any | list[str] | None = None - - -class LoopNodeData(BaseLoopNodeData): - type: NodeType = BuiltinNodeTypes.LOOP - loop_count: int # Maximum number of loops - break_conditions: list[Condition] # Conditions to break the loop - logical_operator: Literal["and", "or"] - loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) - outputs: dict[str, Any] = Field(default_factory=dict) - - @field_validator("outputs", mode="before") - @classmethod - def validate_outputs(cls, v): - if v is None: - return {} - return v - - -class LoopStartNodeData(BaseNodeData): - """ - Loop Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_START - - -class LoopEndNodeData(BaseNodeData): - """ - Loop End Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_END - - -class LoopState(BaseLoopState): - """ - Loop State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseLoopState.MetaData): - """ - Data. - """ - - loop_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. 
- """ - return self.current_output - - -class LoopCompletedReason(StrEnum): - LOOP_BREAK = "loop_break" - LOOP_COMPLETED = "loop_completed" diff --git a/api/graphon/nodes/loop/loop_end_node.py b/api/graphon/nodes/loop/loop_end_node.py deleted file mode 100644 index c0562b59c4..0000000000 --- a/api/graphon/nodes/loop/loop_end_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopEndNodeData - - -class LoopEndNode(Node[LoopEndNodeData]): - """ - Loop End Node. - """ - - node_type = BuiltinNodeTypes.LOOP_END - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/loop/loop_node.py b/api/graphon/nodes/loop/loop_node.py deleted file mode 100644 index d574e9f7ae..0000000000 --- a/api/graphon/nodes/loop/loop_node.py +++ /dev/null @@ -1,428 +0,0 @@ -import contextlib -import json -import logging -from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal, cast - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.graph_events import ( - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from graphon.nodes.base import LLMUsageTrackingMixin -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from graphon.utils.condition.processor import ConditionProcessor -from graphon.variables import Segment, SegmentType, TypeMismatchError, build_segment_with_type, segment_to_variable - -if TYPE_CHECKING: - from graphon.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) -_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" - - -class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): - """ - Loop Node. 
- """ - - node_type = BuiltinNodeTypes.LOOP - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - """Run the node.""" - # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator - - inputs = {"loop_count": loop_count} - - if not self.node_data.start_node_id: - raise ValueError(f"field start_node_id in loop {self._node_id} not found") - - root_node_id = self.node_data.start_node_id - - # Initialize loop variables in the original variable pool - loop_variable_selectors = {} - if self.node_data.loop_variables: - value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { - "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), - "variable": lambda var: ( - self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None - ), - } - for loop_variable in self.node_data.loop_variables: - if loop_variable.value_type not in value_processor: - raise ValueError( - f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}" - ) - - processed_segment = value_processor[loop_variable.value_type](loop_variable) - if not processed_segment: - raise ValueError(f"Invalid value for loop variable {loop_variable.label}") - variable_selector = [self._node_id, loop_variable.label] - variable = segment_to_variable(segment=processed_segment, selector=variable_selector) - self.graph_runtime_state.variable_pool.add(variable_selector, variable.value) - loop_variable_selectors[loop_variable.label] = variable_selector - inputs[loop_variable.label] = processed_segment.value - - start_at = datetime.now(UTC).replace(tzinfo=None) - condition_processor = ConditionProcessor() - - loop_duration_map: dict[str, float] = {} - single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output - loop_usage = LLMUsage.empty_usage() - loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id) - - # Start Loop event - yield LoopStartedEvent( - start_at=start_at, - inputs=inputs, - metadata={"loop_length": loop_count}, - ) - - try: - reach_break_condition = False - if break_conditions: - with contextlib.suppress(ValueError): - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - - if reach_break_condition: - loop_count = 0 - - for i in range(loop_count): - # Clear stale variables from previous loop iterations to avoid streaming old values - self._clear_loop_subgraph_variables(loop_node_ids) - graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - - loop_start_time = datetime.now(UTC).replace(tzinfo=None) - try: - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) - finally: - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - # Track loop duration - loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds() - - # Accumulate outputs from the sub-graph's response nodes - for key, value in graph_engine.graph_runtime_state.outputs.items(): - if key == "answer": - # Concatenate answer outputs with newline - existing_answer = self.graph_runtime_state.get_output("answer", "") - if 
existing_answer: - self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}") - else: - self.graph_runtime_state.set_output("answer", value) - else: - # For other outputs, just update - self.graph_runtime_state.set_output(key, value) - - # Collect loop variable values after iteration - single_loop_variable = {} - for key, selector in loop_variable_selectors.items(): - segment = self.graph_runtime_state.variable_pool.get(selector) - single_loop_variable[key] = segment.value if segment else None - - single_loop_variable_map[str(i)] = single_loop_variable - - if reach_break_node: - break - - if break_conditions: - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - if reach_break_condition: - break - - yield LoopNextEvent( - index=i + 1, - pre_loop_output=self.node_data.outputs, - ) - - self._accumulate_usage(loop_usage) - # Loop completed successfully - yield LoopSucceededEvent( - start_at=start_at, - inputs=inputs, - outputs=self.node_data.outputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: ( - LoopCompletedReason.LOOP_BREAK - if reach_break_condition - else LoopCompletedReason.LOOP_COMPLETED.value - ), - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - outputs=self.node_data.outputs, - inputs=inputs, - llm_usage=loop_usage, - ) - ) - - except Exception as e: - self._accumulate_usage(loop_usage) - yield LoopFailedEvent( - start_at=start_at, - inputs=inputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - "completed_reason": "error", - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - error=str(e), - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - llm_usage=loop_usage, - ) - ) - - def _run_single_loop( - self, - *, - graph_engine: "GraphEngine", - current_index: int, - ) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]: - 
reach_break_node = False - for event in graph_engine.run(): - if isinstance(event, GraphNodeEventBase): - self._append_loop_info_to_event(event=event, loop_run_index=current_index) - - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.LOOP_START: - continue - if isinstance(event, GraphNodeEventBase): - yield event - if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: - reach_break_node = True - if isinstance(event, GraphRunAbortedEvent): - raise RuntimeError(event.reason or _DEFAULT_CHILD_ABORT_REASON) - if isinstance(event, GraphRunFailedEvent): - raise Exception(event.error) - - for loop_var in self.node_data.loop_variables or []: - key, sel = loop_var.label, [self._node_id, loop_var.label] - segment = self.graph_runtime_state.variable_pool.get(sel) - self.node_data.outputs[key] = segment.value if segment else None - self.node_data.outputs["loop_round"] = current_index + 1 - - return reach_break_node - - def _append_loop_info_to_event( - self, - event: GraphNodeEventBase, - loop_run_index: int, - ): - event.in_loop_id = self._node_id - loop_metadata = { - WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **loop_metadata} - - def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None: - """ - Remove variables produced by loop sub-graph nodes from previous iterations. - - Keeping stale variables causes a freshly created response coordinator in the - next iteration to fall back to outdated values when no stream chunks exist. - """ - variable_pool = self.graph_runtime_state.variable_pool - for node_id in loop_node_ids: - variable_pool.remove([node_id]) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LoopNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping = {} - - # Extract loop node IDs statically from graph_config - - loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) - - # Get node configs from graph_config - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("loop_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove loop variables - sub_node_variable_mapping = { - sub_node_id + "." 
+ key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - for loop_variable in node_data.loop_variables or []: - if loop_variable.value_type == "variable": - assert loop_variable.value is not None, "Loop variable value must be provided for variable type" - # add loop variable to variable mapping - selector = loop_variable.value - variable_mapping[f"{node_id}.{loop_variable.label}"] = selector - - # remove variable out from loop - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids} - - return variable_mapping - - @classmethod - def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: - """ - Extract node IDs that belong to a specific loop from graph configuration. - - This method statically analyzes the graph configuration to find all nodes - that are part of the specified loop, without creating actual node instances. - - :param graph_config: the complete graph configuration - :param loop_node_id: the ID of the loop node - :return: set of node IDs that belong to the loop - """ - loop_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_data = node.get("data", {}) - if node_data.get("loop_id") == loop_node_id: - node_id = node.get("id") - if node_id: - loop_node_ids.add(node_id) - - return loop_node_ids - - @staticmethod - def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: - """Get the appropriate segment type for a constant value.""" - # TODO: Refactor for maintainability: - # 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py) - # 2. Consider moving this method to LoopVariableData class for better encapsulation - if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN: - value = original_value - elif var_type in [ - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_STRING, - ]: - if original_value and isinstance(original_value, str): - value = json.loads(original_value) - else: - logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type) - value = [] - else: - raise AssertionError("this statement should be unreachable.") - try: - return build_segment_with_type(var_type, value=value) - except TypeMismatchError as type_exc: - # Attempt to parse the value as a JSON-encoded string, if applicable. - if not isinstance(original_value, str): - raise - try: - value = json.loads(original_value) - except ValueError: - raise type_exc - return build_segment_with_type(var_type, value) - - def _create_graph_engine(self, start_at: datetime, root_node_id: str): - from graphon.entities import GraphInitParams - - # Create GraphInitParams for child graph execution. 
- graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - root_node_id=root_node_id, - ) diff --git a/api/graphon/nodes/loop/loop_start_node.py b/api/graphon/nodes/loop/loop_start_node.py deleted file mode 100644 index 2b17054ae2..0000000000 --- a/api/graphon/nodes/loop/loop_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.loop.entities import LoopStartNodeData - - -class LoopStartNode(Node[LoopStartNodeData]): - """ - Loop Start Node. - """ - - node_type = BuiltinNodeTypes.LOOP_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/graphon/nodes/parameter_extractor/__init__.py b/api/graphon/nodes/parameter_extractor/__init__.py deleted file mode 100644 index bdbf19a7d3..0000000000 --- a/api/graphon/nodes/parameter_extractor/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .parameter_extractor_node import ParameterExtractorNode - -__all__ = ["ParameterExtractorNode"] diff --git a/api/graphon/nodes/parameter_extractor/entities.py b/api/graphon/nodes/parameter_extractor/entities.py deleted file mode 100644 index 8fda1b9e79..0000000000 --- a/api/graphon/nodes/parameter_extractor/entities.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import ( - BaseModel, - BeforeValidator, - Field, - field_validator, -) - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.llm.entities import ModelConfig, VisionConfig -from graphon.prompt_entities import MemoryConfig -from graphon.variables.types import SegmentType - -_OLD_BOOL_TYPE_NAME = "bool" -_OLD_SELECT_TYPE_NAME = "select" - -_VALID_PARAMETER_TYPES = frozenset( - [ - SegmentType.STRING, # "string", - SegmentType.NUMBER, # "number", - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - _OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node - _OLD_SELECT_TYPE_NAME, # string type with enumeration choices. - ] -) - - -def _validate_type(parameter_type: str) -> SegmentType: - if parameter_type not in _VALID_PARAMETER_TYPES: - raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.") - - if parameter_type == _OLD_BOOL_TYPE_NAME: - return SegmentType.BOOLEAN - elif parameter_type == _OLD_SELECT_TYPE_NAME: - return SegmentType.STRING - return SegmentType(parameter_type) - - -class ParameterConfig(BaseModel): - """ - Parameter Config. 
- """ - - name: str - type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: list[str] | None = None - description: str - required: bool - - @field_validator("name", mode="before") - @classmethod - def validate_name(cls, value) -> str: - if not value: - raise ValueError("Parameter name is required") - if value in {"__reason", "__is_success"}: - raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return str(value) - - def is_array_type(self) -> bool: - return self.type.is_array_type() - - def element_type(self) -> SegmentType: - """Return the element type of the parameter. - - Raises a ValueError if the parameter's type is not an array type. - """ - element_type = self.type.element_type() - # At this point, self.type is guaranteed to be one of `ARRAY_STRING`, - # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`. - # - # See: _VALID_PARAMETER_TYPES for reference. - assert element_type is not None, f"the element type should not be None, {self.type=}" - return element_type - - -class ParameterExtractorNodeData(BaseNodeData): - """ - Parameter Extractor Node Data. - """ - - type: NodeType = BuiltinNodeTypes.PARAMETER_EXTRACTOR - model: ModelConfig - query: list[str] - parameters: list[ParameterConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - reasoning_mode: Literal["function_call", "prompt"] - vision: VisionConfig = Field(default_factory=VisionConfig) - - @field_validator("reasoning_mode", mode="before") - @classmethod - def set_reasoning_mode(cls, v) -> str: - return v or "function_call" - - def get_parameter_json_schema(self): - """ - Get parameter json schema. - - :return: parameter json schema - """ - parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} - - for parameter in self.parameters: - parameter_schema: dict[str, Any] = {"description": parameter.description} - - if parameter.type == SegmentType.STRING: - parameter_schema["type"] = "string" - elif parameter.type.is_array_type(): - parameter_schema["type"] = "array" - element_type = parameter.type.element_type() - if element_type is None: - raise AssertionError("element type should not be None.") - parameter_schema["items"] = {"type": element_type.value} - else: - parameter_schema["type"] = parameter.type - - if parameter.options: - parameter_schema["enum"] = parameter.options - - parameters["properties"][parameter.name] = parameter_schema - - if parameter.required: - parameters["required"].append(parameter.name) - - return parameters diff --git a/api/graphon/nodes/parameter_extractor/exc.py b/api/graphon/nodes/parameter_extractor/exc.py deleted file mode 100644 index faa90313c1..0000000000 --- a/api/graphon/nodes/parameter_extractor/exc.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Any - -from graphon.variables.types import SegmentType - - -class ParameterExtractorNodeError(ValueError): - """Base error for ParameterExtractorNode.""" - - -class InvalidModelTypeError(ParameterExtractorNodeError): - """Raised when the model is not a Large Language Model.""" - - -class ModelSchemaNotFoundError(ParameterExtractorNodeError): - """Raised when the model schema is not found.""" - - -class InvalidInvokeResultError(ParameterExtractorNodeError): - """Raised when the invoke result is invalid.""" - - -class InvalidTextContentTypeError(ParameterExtractorNodeError): - """Raised when the text content type is invalid.""" - - -class InvalidNumberOfParametersError(ParameterExtractorNodeError): - """Raised when the number of parameters 
is invalid.""" - - -class RequiredParameterMissingError(ParameterExtractorNodeError): - """Raised when a required parameter is missing.""" - - -class InvalidSelectValueError(ParameterExtractorNodeError): - """Raised when a select value is invalid.""" - - -class InvalidNumberValueError(ParameterExtractorNodeError): - """Raised when a number value is invalid.""" - - -class InvalidBoolValueError(ParameterExtractorNodeError): - """Raised when a bool value is invalid.""" - - -class InvalidStringValueError(ParameterExtractorNodeError): - """Raised when a string value is invalid.""" - - -class InvalidArrayValueError(ParameterExtractorNodeError): - """Raised when an array value is invalid.""" - - -class InvalidModelModeError(ParameterExtractorNodeError): - """Raised when the model mode is invalid.""" - - -class InvalidValueTypeError(ParameterExtractorNodeError): - def __init__( - self, - /, - parameter_name: str, - expected_type: SegmentType, - actual_type: SegmentType | None, - value: Any, - ): - message = ( - f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, " - f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}" - ) - super().__init__(message) - self.parameter_name = parameter_name - self.expected_type = expected_type - self.actual_type = actual_type - self.value = value diff --git a/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py b/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py deleted file mode 100644 index 25379e325c..0000000000 --- a/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py +++ /dev/null @@ -1,846 +0,0 @@ -import contextlib -import json -import logging -import uuid -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File -from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.node_events import NodeRunResult -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.node import Node -from graphon.nodes.llm import LLMNode, llm_utils -from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol -from graphon.runtime import VariablePool -from graphon.variables import build_segment_with_type -from graphon.variables.types import ArrayValidation, SegmentType - -from .entities import ParameterExtractorNodeData -from .exc import ( - InvalidModelModeError, - InvalidModelTypeError, - InvalidNumberOfParametersError, - InvalidSelectValueError, - InvalidTextContentTypeError, - InvalidValueTypeError, - ModelSchemaNotFoundError, - ParameterExtractorNodeError, - RequiredParameterMissingError, -) -from .prompts import ( - CHAT_EXAMPLE, - 
CHAT_GENERATE_JSON_PROMPT, - CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, - COMPLETION_GENERATE_JSON_PROMPT, - FUNCTION_CALLING_EXTRACTOR_EXAMPLE, - FUNCTION_CALLING_EXTRACTOR_NAME, - FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, - FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, -) - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -def extract_json(text): - """ - From a given JSON started from '{' or '[' extract the complete JSON object. - """ - stack = [] - for i, c in enumerate(text): - if c in {"{", "["}: - stack.append(c) - elif c in {"}", "]"}: - # check if stack is empty - if not stack: - return text[:i] - # check if the last element in stack is matching - if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): - stack.pop() - if not stack: - return text[: i + 1] - else: - return text[:i] - return None - - -class ParameterExtractorNode(Node[ParameterExtractorNodeData]): - """ - Parameter Extractor Node. - """ - - node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - - _model_instance: PreparedLLMProtocol - _prompt_message_serializer: PromptMessageSerializerProtocol - _memory: PromptMessageMemory | None - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: object | None = None, - model_factory: object | None = None, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None = None, - prompt_message_serializer: PromptMessageSerializerProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - _ = credentials_provider, model_factory - self._model_instance = model_instance - self._prompt_message_serializer = prompt_message_serializer - self._memory = memory - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "model": { - "prompt_templates": { - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "stop": ["Human:"], - } - } - } - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - """ - Run the node. 
- """ - node_data = self.node_data - variable = self.graph_runtime_state.variable_pool.get(node_data.query) - query = variable.text if variable else "" - - variable_pool = self.graph_runtime_state.variable_pool - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - if model_schema.model_type != ModelType.LLM: - raise InvalidModelTypeError("Model is not a Large Language Model") - memory = self._memory - - if ( - set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} - and node_data.reasoning_mode == "function_call" - ): - # use function call - prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - else: - # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt( - data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - - prompt_message_tools = [] - - inputs = { - "query": query, - "files": [f.to_dict() for f in files], - "parameters": jsonable_encoder(node_data.parameters), - "instruction": jsonable_encoder(node_data.instruction), - } - - process_data = { - "model_mode": node_data.model.mode, - "prompts": self._prompt_message_serializer.serialize( - model_mode=node_data.model.mode, - prompt_messages=prompt_messages, - ), - "usage": None, - "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - "tool_call": None, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - - try: - text, usage, tool_call = self._invoke( - model_instance=model_instance, - prompt_messages=prompt_messages, - tools=prompt_message_tools, - stop=model_instance.stop, - ) - process_data["usage"] = jsonable_encoder(usage) - process_data["tool_call"] = jsonable_encoder(tool_call) - process_data["llm_text"] = text - except ParameterExtractorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": str(e)}, - error=str(e), - metadata={}, - ) - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, - error=str(e), - metadata={}, - ) - - error = None - - if tool_call: - result = self._extract_json_from_tool_call(tool_call) - else: - result = self._extract_complete_json_response(text) - if not result: - result = self._generate_default_result(node_data) - error = "Failed to extract result from function call or text response, using empty result." 
- - try: - result = self._validate_result(data=node_data, result=result or {}) - except ParameterExtractorNodeError as e: - error = str(e) - - # transform result into standard format - result = self._transform_result(data=node_data, result=result or {}) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={ - "__is_success": 1 if not error else 0, - "__reason": error, - "__usage": jsonable_encoder(usage), - **result, - }, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - def _invoke( - self, - model_instance: PreparedLLMProtocol, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: Sequence[str] | None, - ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = cast( - LLMResult, - model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools or None, - stop=stop, - stream=False, - ), - ) - - # handle invoke result - - text = invoke_result.message.get_text_content() - if not isinstance(text, str): - raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") - - usage = invoke_result.usage - tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None - - return text, usage, tool_call - - def _generate_function_call_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: - """ - Generate function call prompt. - """ - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( - content=query, structure=json.dumps(node_data.get_parameter_json_schema()) - ) - - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_function_calling_prompt_template( - node_data, query, variable_pool, memory, rest_token - ) - prompt_messages = self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add function call messages before last user message - example_messages = [] - for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: - id = uuid.uuid4().hex - example_messages.extend( - [ - UserPromptMessage(content=example["user"]["query"]), - AssistantPromptMessage( - content=example["assistant"]["text"], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example["assistant"]["function_call"]["name"], - arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), - ), - ) - ], - ), - ToolPromptMessage( - content="Great! 
You have called the function with the correct parameters.", tool_call_id=id - ), - AssistantPromptMessage( - content="I have extracted the parameters, let's move on.", - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - # generate tool - tool = PromptMessageTool( - name=FUNCTION_CALLING_EXTRACTOR_NAME, - description="Extract parameters from the natural language text", - parameters=node_data.get_parameter_json_schema(), - ) - - return prompt_messages, [tool] - - def _generate_prompt_engineering_prompt( - self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate prompt engineering prompt. - """ - if data.model.mode == LLMMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - if data.model.mode == LLMMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}") - - def _generate_prompt_engineering_completion_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate completion prompt. - """ - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token - ) - return self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - def _generate_prompt_engineering_chat_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate chat prompt. 
- """ - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, - query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), text=query - ), - variable_pool=variable_pool, - memory=memory, - max_token_limit=rest_token, - ) - - prompt_messages = self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=files, - vision_enabled=node_data.vision.enabled, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add example messages before last user message - example_messages = [] - for example in CHAT_EXAMPLE: - example_messages.extend( - [ - UserPromptMessage( - content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example["user"]["json"]), - text=example["user"]["query"], - ) - ), - AssistantPromptMessage( - content=json.dumps(example["assistant"]["json"]), - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - return prompt_messages - - def _validate_result(self, data: ParameterExtractorNodeData, result: dict): - if len(data.parameters) != len(result): - raise InvalidNumberOfParametersError("Invalid number of parameters") - - for parameter in data.parameters: - if parameter.required and parameter.name not in result: - raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") - - param_value = result.get(parameter.name) - if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL): - inferred_type = SegmentType.infer_segment_type(param_value) - raise InvalidValueTypeError( - parameter_name=parameter.name, - expected_type=parameter.type, - actual_type=inferred_type, - value=param_value, - ) - if parameter.type == SegmentType.STRING and parameter.options: - if param_value not in parameter.options: - raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") - return result - - @staticmethod - def _transform_number(value: int | float | str | bool) -> int | float | None: - """ - Attempts to transform the input into an integer or float. - - Returns: - int or float: The transformed number if the conversion is successful. - None: If the transformation fails. - - Note: - Boolean values `True` and `False` are converted to integers `1` and `0`, respectively. - This behavior ensures compatibility with existing workflows that may use boolean types as integers. - """ - if isinstance(value, bool): - return int(value) - elif isinstance(value, (int, float)): - return value - elif isinstance(value, str): - if "." in value: - try: - return float(value) - except ValueError: - return None - else: - try: - return int(value) - except ValueError: - return None - else: - return None - - def _transform_result(self, data: ParameterExtractorNodeData, result: dict): - """ - Transform result into standard format. 
- """ - transformed_result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.name in result: - param_value = result[parameter.name] - # transform value - if parameter.type == SegmentType.NUMBER: - transformed = self._transform_number(param_value) - if transformed is not None: - transformed_result[parameter.name] = transformed - elif parameter.type == SegmentType.BOOLEAN: - if isinstance(result[parameter.name], (bool, int)): - transformed_result[parameter.name] = bool(result[parameter.name]) - # elif isinstance(result[parameter.name], str): - # if result[parameter.name].lower() in ["true", "false"]: - # transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true") - elif parameter.type == SegmentType.STRING: - if isinstance(param_value, str): - transformed_result[parameter.name] = param_value - elif parameter.is_array_type(): - if isinstance(param_value, list): - nested_type = parameter.element_type() - assert nested_type is not None - segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) - transformed_result[parameter.name] = segment_value - for item in param_value: - if nested_type == SegmentType.NUMBER: - transformed = self._transform_number(item) - if transformed is not None: - segment_value.value.append(transformed) - elif nested_type == SegmentType.STRING: - if isinstance(item, str): - segment_value.value.append(item) - elif nested_type == SegmentType.OBJECT: - if isinstance(item, dict): - segment_value.value.append(item) - elif nested_type == SegmentType.BOOLEAN: - if isinstance(item, bool): - segment_value.value.append(item) - - if parameter.name not in transformed_result: - if parameter.type.is_array_type(): - transformed_result[parameter.name] = build_segment_with_type( - segment_type=SegmentType(parameter.type), value=[] - ) - elif parameter.type in (SegmentType.STRING, SegmentType.SECRET): - transformed_result[parameter.name] = "" - elif parameter.type == SegmentType.NUMBER: - transformed_result[parameter.name] = 0 - elif parameter.type == SegmentType.BOOLEAN: - transformed_result[parameter.name] = False - else: - raise AssertionError("this statement should be unreachable.") - - return transformed_result - - def _extract_complete_json_response(self, result: str) -> dict | None: - """ - Extract complete json response. - """ - - # extract json from the text - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - logger.info("extra error: %s", result) - return None - - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: - """ - Extract json from tool call. - """ - if not tool_call or not tool_call.function.arguments: - return None - - result = tool_call.function.arguments - # extract json from the arguments - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - - logger.info("extra error: %s", result) - return None - - def _generate_default_result(self, data: ParameterExtractorNodeData): - """ - Generate default result. 
- """ - result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.type == "number": - result[parameter.name] = 0 - elif parameter.type == "boolean": - result[parameter.name] = False - elif parameter.type in {"string", "select"}: - result[parameter.name] = "" - - return result - - def _get_function_calling_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ) -> list[LLMNodeChatModelMessage]: - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if node_data.model.mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), - ) - user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") - - def _get_prompt_engineering_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if node_data.model.mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), - ) - user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - if node_data.model.mode == LLMMode.COMPLETION: - return LLMNodeCompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format( - histories=memory_str, text=input_text, instruction=instruction - ) - .replace("{γγγ", "") - .replace("}γγγ", "") - .replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())), - ) - raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") - - def _calculate_rest_token( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: PreparedLLMProtocol, - context: str | None, - ) -> int: - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - - prompt_template: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) - else: - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - - prompt_messages 
= self._compile_prompt_messages( - model_instance=model_instance, - prompt_template=prompt_template, - files=[], - vision_enabled=False, - context=context, - ) - rest_tokens = 2000 - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000 - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - def _compile_prompt_messages( - self, - *, - model_instance: PreparedLLMProtocol, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - files: Sequence[File], - vision_enabled: bool, - context: str | None = "", - image_detail_config: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - prompt_messages, _ = LLMNode.fetch_prompt_messages( - sys_query="", - sys_files=files, - context=context or "", - memory=None, - model_instance=model_instance, - prompt_template=prompt_template, - stop=model_instance.stop, - memory_config=None, - vision_enabled=vision_enabled, - vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - ) - return list(prompt_messages) - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ParameterExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} - - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) - for selector in selectors: - variable_mapping[selector.variable] = selector.value_selector - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping diff --git a/api/graphon/nodes/parameter_extractor/prompts.py b/api/graphon/nodes/parameter_extractor/prompts.py deleted file mode 100644 index 1b29be4418..0000000000 --- a/api/graphon/nodes/parameter_extractor/prompts.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Any - -FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" - -FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. -### Task -Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria. -### Memory -Here is the chat history between the human and assistant, provided within tags: - -\x7bhistories\x7d - -### Instructions: -Some additional information is provided below. Always adhere to these instructions as closely as possible: - -\x7binstruction\x7d - -Steps: -1. Review the chat history provided within the tags. -2. 
Extract the relevant information based on the criteria given; output multiple values if there are multiple pieces of relevant information that match the criteria in the given text. -3. Generate a well-formatted output using the defined functions and arguments. -4. Use the `extract_parameters` function to create structured outputs with appropriate parameters. -5. Do not include any XML tags in your output. -### Example -To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to the examples. -### Final Output -Produce well-formatted function calls in json without XML tags, as shown in the example. -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from <context> inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure> XML tags. -<context> -\x7bcontent\x7d -</context> - -<structure> -\x7bstructure\x7d -</structure> -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ - { - "user": { - "query": "What is the weather today in SF?", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - }, - }, - "required": ["location"], - }, - }, - }, - "assistant": { - "text": "I always need to call the function with the correct parameters." - " In this case, I need to call the function with the location parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, - }, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - }, - "assistant": { - "text": "I always need to call the function with the correct parameters." - " In this case, I need to call the function with the food parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, - }, - }, -] - -COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: -Some extra information is provided below; I should always follow the instructions as closely as I can. -<instruction> -{instruction} -</instruction> - -### Extract Parameter Workflow -I need to extract the following information from the input text. The <structure> tag specifies the 'type', 'description' and 'required' of the information to be extracted. -<structure> -{{ structure }} -</structure> - -Step 1: Carefully read the input and understand the structure of the expected output. -Step 2: Extract relevant parameters from the provided text based on the name and description of each object. -Step 3: Structure the extracted parameters into a JSON object as specified in <structure>. -Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be output. - -### Memory -Here are the chat histories between human and assistant, inside <histories> XML tags. -<histories> -{histories} -</histories> - -### Structure -Here is the structure of the expected output; I should always follow the output structure.
-{{γγγ - 'properties1': 'relevant text extracted from input', - 'properties2': 'relevant text extracted from input', -}}γγγ - -### Input Text -Inside <text> XML tags, there is a text from which I should extract parameters and convert into a JSON object. -<text> -{text} -</text> - -### Answer -I should always output a valid JSON object. Output nothing other than the JSON object. -```JSON -""" # noqa: E501 - -CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. -The structure of the JSON object can be found in the instructions. - -### Memory -Here are the chat histories between human and assistant, inside <histories> XML tags. -<histories> -{histories} -</histories> - -### Instructions: -Some extra information is provided below; you should always follow the instructions as closely as you can. - -{instructions} - -""" - -CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure -Here is the structure of the JSON object; you should always follow the structure. -<structure> -{structure} -</structure> - -### Text to be converted to JSON -Inside <text> XML tags, there is a text that you should convert to a JSON object. -<text> -{text} -</text> -""" - -CHAT_EXAMPLE = [ - { - "user": { - "query": "What is the weather today in SF?", - "json": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - } - }, - "required": ["location"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "json": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}}, - }, -] diff --git a/api/graphon/nodes/protocols.py b/api/graphon/nodes/protocols.py deleted file mode 100644 index 4b050c113c..0000000000 --- a/api/graphon/nodes/protocols.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Generator, Mapping -from typing import Any, Protocol - -import httpx - -from graphon.file import File - - -class HttpClientProtocol(Protocol): - @property - def max_retries_exceeded_error(self) -> type[Exception]: ... - - @property - def request_error(self) -> type[Exception]: ... - - def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - -class FileManagerProtocol(Protocol): - def download(self, f: File, /) -> bytes: ... - - -class ToolFileManagerProtocol(Protocol): - def create_file_by_raw( - self, - *, - file_binary: bytes, - mimetype: str, - filename: str | None = None, - ) -> Any: ... - - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: ... - - -class FileReferenceFactoryProtocol(Protocol): - def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ...
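[Editor's note] The protocols deleted above rely on PEP 544 structural subtyping: a concrete class never inherits from `HttpClientProtocol` or `FileManagerProtocol`, it only has to expose matching methods, which is what lets the workflow layer inject its own implementations and lets tests inject stubs. A minimal sketch of that mechanism, with a simplified `download(str) -> bytes` signature standing in for the real `File`-based one and an invented `InMemoryFileManager` name:

```python
from typing import Protocol


class FileManagerProtocol(Protocol):
    # Simplified for illustration: the real protocol takes a File, not a str.
    def download(self, f: str, /) -> bytes: ...


class InMemoryFileManager:
    """Test double; conforms to the protocol without inheriting from it."""

    def __init__(self, store: dict[str, bytes]) -> None:
        self._store = store

    def download(self, f: str, /) -> bytes:
        return self._store[f]


def read_all(manager: FileManagerProtocol, keys: list[str]) -> list[bytes]:
    # A type checker accepts any object that structurally matches the protocol.
    return [manager.download(k) for k in keys]


assert read_all(InMemoryFileManager({"a": b"hi"}), ["a"]) == [b"hi"]
```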
diff --git a/api/graphon/nodes/question_classifier/__init__.py b/api/graphon/nodes/question_classifier/__init__.py deleted file mode 100644 index 4d06b6bea3..0000000000 --- a/api/graphon/nodes/question_classifier/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import QuestionClassifierNodeData -from .question_classifier_node import QuestionClassifierNode - -__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"] diff --git a/api/graphon/nodes/question_classifier/entities.py b/api/graphon/nodes/question_classifier/entities.py deleted file mode 100644 index 8d5f117315..0000000000 --- a/api/graphon/nodes/question_classifier/entities.py +++ /dev/null @@ -1,30 +0,0 @@ -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.llm import ModelConfig, VisionConfig -from graphon.prompt_entities import MemoryConfig - - -class ClassConfig(BaseModel): - id: str - name: str - - -class QuestionClassifierNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.QUESTION_CLASSIFIER - query_variable_selector: list[str] - model: ModelConfig - classes: list[ClassConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - vision: VisionConfig = Field(default_factory=VisionConfig) - - @property - def structured_output_enabled(self) -> bool: - # NOTE(QuantumGhost): Temporary workaround for issue #20725 - # (https://github.com/langgenius/dify/issues/20725). - # - # The proper fix would be to make `QuestionClassifierNode` inherit - # from `BaseNode` instead of `LLMNode`. - return False diff --git a/api/graphon/nodes/question_classifier/exc.py b/api/graphon/nodes/question_classifier/exc.py deleted file mode 100644 index 2c6354e2a7..0000000000 --- a/api/graphon/nodes/question_classifier/exc.py +++ /dev/null @@ -1,6 +0,0 @@ -class QuestionClassifierNodeError(ValueError): - """Base class for QuestionClassifierNode errors.""" - - -class InvalidModelTypeError(QuestionClassifierNodeError): - """Raised when the model is not a Large Language Model.""" diff --git a/api/graphon/nodes/question_classifier/question_classifier_node.py b/api/graphon/nodes/question_classifier/question_classifier_node.py deleted file mode 100644 index a30ffbb149..0000000000 --- a/api/graphon/nodes/question_classifier/question_classifier_node.py +++ /dev/null @@ -1,395 +0,0 @@ -import json -import re -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.node_events import ModelInvokeCompletedEvent, NodeRunResult -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.nodes.llm import ( - LLMNode, - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - llm_utils, -) -from graphon.nodes.llm.file_saver import LLMFileSaver -from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol -from 
graphon.nodes.protocols import HttpClientProtocol -from graphon.template_rendering import Jinja2TemplateRenderer -from graphon.utils.json_in_md_parser import parse_and_check_json_markdown - -from .entities import QuestionClassifierNodeData -from .exc import InvalidModelTypeError -from .template_prompts import ( - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, - QUESTION_CLASSIFIER_COMPLETION_PROMPT, - QUESTION_CLASSIFIER_SYSTEM_PROMPT, - QUESTION_CLASSIFIER_USER_PROMPT_1, - QUESTION_CLASSIFIER_USER_PROMPT_2, - QUESTION_CLASSIFIER_USER_PROMPT_3, -) - -if TYPE_CHECKING: - from graphon.file.models import File - from graphon.runtime import GraphRuntimeState - - -class _PassthroughPromptMessageSerializer: - def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any: - _ = model_mode - return list(prompt_messages) - - -class QuestionClassifierNode(Node[QuestionClassifierNodeData]): - node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER - execution_type = NodeExecutionType.BRANCH - - _file_outputs: list["File"] - _llm_file_saver: LLMFileSaver - _prompt_message_serializer: PromptMessageSerializerProtocol - _model_instance: PreparedLLMProtocol - _memory: PromptMessageMemory | None - _template_renderer: Jinja2TemplateRenderer - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: object | None = None, - model_factory: object | None = None, - model_instance: PreparedLLMProtocol, - http_client: HttpClientProtocol, - template_renderer: Jinja2TemplateRenderer, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver, - prompt_message_serializer: PromptMessageSerializerProtocol | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for MultiModal outputs. 
- self._file_outputs = [] - - _ = credentials_provider, model_factory, http_client - self._model_instance = model_instance - self._memory = memory - self._template_renderer = template_renderer - - self._llm_file_saver = llm_file_saver - self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer() - - @classmethod - def version(cls): - return "1" - - def _run(self): - node_data = self.node_data - variable_pool = self.graph_runtime_state.variable_pool - - # extract variables - variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None - query = variable.value if variable else None - variables = {"query": query} - # fetch model instance - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - memory = self._memory - # fetch instruction - node_data.instruction = node_data.instruction or "" - node_data.instruction = variable_pool.convert_template(node_data.instruction).text - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - # fetch prompt messages - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query or "", - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_template( - node_data=node_data, - query=query or "", - memory=memory, - max_token_limit=rest_token, - ) - # Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...). - # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, - # two consecutive user prompts will be generated, causing model's error. - # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. 
- prompt_messages, stop = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - memory=memory, - model_instance=model_instance, - stop=model_instance.stop, - sys_files=files, - vision_enabled=node_data.vision.enabled, - vision_detail=node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - - result_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - - try: - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - structured_output_enabled=False, - structured_output=None, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - ) - - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - break - - rendered_classes = [ - c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes - ] - - category_name = rendered_classes[0].name - category_id = rendered_classes[0].id - if "<think>" in result_text: - result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE) - result_text_json = parse_and_check_json_markdown(result_text, []) - # result_text_json = json.loads(result_text.strip('```JSON\n')) - if "category_name" in result_text_json and "category_id" in result_text_json: - category_id_result = result_text_json["category_id"] - classes = rendered_classes - classes_map = {class_.id: class_.name for class_ in classes} - category_ids = [_class.id for _class in classes] - if category_id_result in category_ids: - category_name = classes_map[category_id_result] - category_id = category_id_result - process_data = { - "model_mode": node_data.model.mode, - "prompts": self._prompt_message_serializer.serialize( - model_mode=node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - outputs = { - "class_name": category_name, - "class_id": category_id, - "usage": jsonable_encoder(usage), - } - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=process_data, - outputs=outputs, - edge_source_handle=category_id, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - except ValueError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e), - error_type=type(e).__name__, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - @property - def model_instance(self) -> PreparedLLMProtocol: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: QuestionClassifierNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - variable_mapping = {"query": node_data.query_variable_selector} -
variable_selectors: list[VariableSelector] = [] - if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters (not used in this implementation). - :return: - """ - # filters parameter is not used in this node type - return {"type": "question-classifier", "config": {"instructions": ""}} - - def _calculate_rest_token( - self, - node_data: QuestionClassifierNodeData, - query: str, - model_instance: PreparedLLMProtocol, - context: str | None, - ) -> int: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - prompt_template = self._get_prompt_template(node_data, query, None, 2000) - prompt_messages, _ = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - sys_files=[], - context=context or "", - memory=None, - model_instance=model_instance, - stop=model_instance.stop, - memory_config=node_data.memory, - vision_enabled=False, - vision_detail=node_data.vision.configs.detail, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - rest_tokens = 2000 - - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - def _get_prompt_template( - self, - node_data: QuestionClassifierNodeData, - query: str, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ): - model_mode = LLMMode(node_data.model.mode) - classes = node_data.classes - categories = [] - for class_ in classes: - category = {"category_id": class_.id, "category_name": class_.name} - categories.append(category) - instruction = node_data.instruction or "" - input_text = query - memory_str = "" - if memory: - memory_str = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, - ) - prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == LLMMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) - ) - prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 - ) - prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = 
LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 - ) - prompt_messages.append(assistant_prompt_message_1) - user_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 - ) - prompt_messages.append(user_prompt_message_2) - assistant_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 - ) - prompt_messages.append(assistant_prompt_message_2) - user_prompt_message_3 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( - input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction, - ), - ) - prompt_messages.append(user_prompt_message_3) - return prompt_messages - elif model_mode == LLMMode.COMPLETION: - return LLMNodeCompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( - histories=memory_str, - input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction, - ) - ) - - else: - raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.") diff --git a/api/graphon/nodes/question_classifier/template_prompts.py b/api/graphon/nodes/question_classifier/template_prompts.py deleted file mode 100644 index a615c32383..0000000000 --- a/api/graphon/nodes/question_classifier/template_prompts.py +++ /dev/null @@ -1,76 +0,0 @@ -QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ -### Job Description -You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. -### Task -Your task is to assign ONLY one category to the input text, and only one category may be returned in the output. Additionally, you need to extract the keywords from the text that are related to the classification. -### Format -The input text is in the variable input_text. Categories are specified as a category list with two fields, category_id and category_name, in the variable categories. Classification instructions may be included to improve the classification accuracy. -### Constraint -DO NOT include anything other than the JSON object in your response. -### Memory -Here are the chat histories between human and assistant, inside <histories> XML tags. -<histories> -{histories} -</histories> -""" # noqa: E501 - -QUESTION_CLASSIFIER_USER_PROMPT_1 = """ - {"input_text": ["I recently had a great experience with your company.
The service was prompt and the staff was very friendly."], - "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}], - "classification_instructions": ["classify the text based on the feedback provided by customer"]} -""" # noqa: E501 - -QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ -```json - {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"], - "category_id": "f5660049-284f-41a7-b301-fd24176a711c", - "category_name": "Customer Service"} -``` -""" - -QUESTION_CLASSIFIER_USER_PROMPT_2 = """ - {"input_text": ["bad service, slow to bring the food"], - "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}], - "classification_instructions": []} -""" # noqa: E501 - -QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ -```json - {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], - "category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f", - "category_name": "Experience"} -``` -""" - -QUESTION_CLASSIFIER_USER_PROMPT_3 = """ - {{"input_text": ["{input_text}"], - "categories": {categories}, - "classification_instructions": ["{classification_instructions}"]}} -""" - -QUESTION_CLASSIFIER_COMPLETION_PROMPT = """ -### Job Description -You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. -### Task -Your task is to assign ONLY one category to the input text, and only one category may be returned in the output. Additionally, you need to extract the keywords from the text that are related to the classification. -### Format -The input text is in the variable input_text. Categories are specified as a category list with two fields, category_id and category_name, in the variable categories. Classification instructions may be included to improve the classification accuracy. -### Constraint -DO NOT include anything other than the JSON object in your response. -### Example -Here is the chat example between human and assistant, inside XML tags. - -User:{{"input_text": ["I recently had a great experience with your company.
The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}} -Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}} -User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}} -Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}} - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - -### User Input -{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}} -### Assistant Output -""" # noqa: E501 diff --git a/api/graphon/nodes/runtime.py b/api/graphon/nodes/runtime.py deleted file mode 100644 index 650299898c..0000000000 --- a/api/graphon/nodes/runtime.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Mapping, Sequence -from datetime import datetime -from typing import TYPE_CHECKING, Any, Protocol - -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.nodes.tool_runtime_entities import ( - ToolRuntimeHandle, - ToolRuntimeMessage, - ToolRuntimeParameter, -) - -if TYPE_CHECKING: - from graphon.nodes.human_input.entities import HumanInputNodeData - from graphon.nodes.human_input.enums import HumanInputFormStatus - from graphon.nodes.tool.entities import ToolNodeData - from graphon.runtime import VariablePool - - -class ToolNodeRuntimeProtocol(Protocol): - """Workflow-layer adapter owned by `core.workflow` and consumed by `graphon`. - - The graph package depends only on these DTOs and lets the workflow layer - translate between graph-owned abstractions and `core.tools` internals. - """ - - def get_runtime( - self, - *, - node_id: str, - node_data: ToolNodeData, - variable_pool: VariablePool | None, - ) -> ToolRuntimeHandle: ... - - def get_runtime_parameters( - self, - *, - tool_runtime: ToolRuntimeHandle, - ) -> Sequence[ToolRuntimeParameter]: ... - - def invoke( - self, - *, - tool_runtime: ToolRuntimeHandle, - tool_parameters: Mapping[str, Any], - workflow_call_depth: int, - provider_name: str, - ) -> Generator[ToolRuntimeMessage, None, None]: ... - - def get_usage( - self, - *, - tool_runtime: ToolRuntimeHandle, - ) -> LLMUsage: ... - - def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... - - def resolve_provider_icons( - self, - *, - provider_name: str, - default_icon: str | None = None, - ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ... 
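[Editor's note] Because `ToolNode` receives a `ToolNodeRuntimeProtocol` through its constructor, this seam also makes the node testable without touching `core.tools`. A rough sketch of the idea, with an invented `FakeMessage` DTO and `FakeRuntime` standing in for the real `ToolRuntimeMessage` stream (all names here are illustrative, not part of the diff):

```python
from collections.abc import Generator
from dataclasses import dataclass
from typing import Protocol


@dataclass
class FakeMessage:
    """Simplified stand-in for a tool runtime message DTO."""

    type: str
    text: str


class InvokerProtocol(Protocol):
    def invoke(self, *, tool_parameters: dict) -> Generator[FakeMessage, None, None]: ...


class FakeRuntime:
    """Scripted runtime: replays canned messages instead of invoking a real tool."""

    def __init__(self, script: list[FakeMessage]) -> None:
        self._script = script

    def invoke(self, *, tool_parameters: dict) -> Generator[FakeMessage, None, None]:
        yield from self._script


def drain_text(runtime: InvokerProtocol) -> str:
    # The consumer only sees the protocol, so the fake is a drop-in replacement.
    return "".join(m.text for m in runtime.invoke(tool_parameters={}) if m.type == "text")


assert drain_text(FakeRuntime([FakeMessage("text", "hello "), FakeMessage("text", "world")])) == "hello world"
```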
- - -class HumanInputNodeRuntimeProtocol(Protocol): - """Workflow-layer adapter for human-input runtime persistence and delivery.""" - - def get_form( - self, - *, - node_id: str, - ) -> HumanInputFormStateProtocol | None: ... - - def create_form( - self, - *, - node_id: str, - node_data: HumanInputNodeData, - rendered_content: str, - resolved_default_values: Mapping[str, Any], - ) -> HumanInputFormStateProtocol: ... - - -class HumanInputFormStateProtocol(Protocol): - @property - def id(self) -> str: ... - - @property - def rendered_content(self) -> str: ... - - @property - def selected_action_id(self) -> str | None: ... - - @property - def submitted_data(self) -> Mapping[str, Any] | None: ... - - @property - def submitted(self) -> bool: ... - - @property - def status(self) -> HumanInputFormStatus: ... - - @property - def expiration_time(self) -> datetime: ... diff --git a/api/graphon/nodes/start/__init__.py b/api/graphon/nodes/start/__init__.py deleted file mode 100644 index 5411780423..0000000000 --- a/api/graphon/nodes/start/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .start_node import StartNode - -__all__ = ["StartNode"] diff --git a/api/graphon/nodes/start/entities.py b/api/graphon/nodes/start/entities.py deleted file mode 100644 index 7df62e1b2b..0000000000 --- a/api/graphon/nodes/start/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence - -from pydantic import Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.variables.input_entities import VariableEntity - - -class StartNodeData(BaseNodeData): - """ - Start Node Data - """ - - type: NodeType = BuiltinNodeTypes.START - variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/graphon/nodes/start/start_node.py b/api/graphon/nodes/start/start_node.py deleted file mode 100644 index cb3f4c1e7d..0000000000 --- a/api/graphon/nodes/start/start_node.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any - -from jsonschema import Draft7Validator, ValidationError - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.start.entities import StartNodeData -from graphon.variables.input_entities import VariableEntityType - - -class StartNode(Node[StartNodeData]): - node_type = BuiltinNodeTypes.START - execution_type = NodeExecutionType.ROOT - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) - self._validate_and_normalize_json_object_inputs(node_inputs) - outputs = dict(self.graph_runtime_state.variable_pool.flatten(unprefixed_node_id=self.id)) - outputs.update(node_inputs) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) - - def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None: - for variable in self.node_data.variables: - if variable.type != VariableEntityType.JSON_OBJECT: - continue - - key = variable.variable - value = node_inputs.get(key) - - if value is None and variable.required: - raise ValueError(f"{key} is required in input form") - - # If no value provided, skip further processing for this key - if not value: - continue - - if not isinstance(value, dict): - raise ValueError(f"JSON object for '{key}' must be an object") - - # 
Overwrite with normalized dict to ensure downstream consistency - node_inputs[key] = value - - # If schema exists, then validate against it - schema = variable.json_schema - if not schema: - continue - - try: - Draft7Validator(schema).validate(value) - except ValidationError as e: - raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") diff --git a/api/graphon/nodes/template_transform/__init__.py b/api/graphon/nodes/template_transform/__init__.py deleted file mode 100644 index 43863b9d59..0000000000 --- a/api/graphon/nodes/template_transform/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .template_transform_node import TemplateTransformNode - -__all__ = ["TemplateTransformNode"] diff --git a/api/graphon/nodes/template_transform/entities.py b/api/graphon/nodes/template_transform/entities.py deleted file mode 100644 index a27a57f34f..0000000000 --- a/api/graphon/nodes/template_transform/entities.py +++ /dev/null @@ -1,13 +0,0 @@ -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.base.entities import VariableSelector - - -class TemplateTransformNodeData(BaseNodeData): - """ - Template Transform Node Data. - """ - - type: NodeType = BuiltinNodeTypes.TEMPLATE_TRANSFORM - variables: list[VariableSelector] - template: str diff --git a/api/graphon/nodes/template_transform/template_transform_node.py b/api/graphon/nodes/template_transform/template_transform_node.py deleted file mode 100644 index 4206fb0c1a..0000000000 --- a/api/graphon/nodes/template_transform/template_transform_node.py +++ /dev/null @@ -1,119 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.entities import VariableSelector -from graphon.nodes.base.node import Node -from graphon.nodes.template_transform.entities import TemplateTransformNodeData -from graphon.template_rendering import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - -DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 - - -class TemplateTransformNode(Node[TemplateTransformNodeData]): - node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _jinja2_template_renderer: Jinja2TemplateRenderer - _max_output_length: int - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - jinja2_template_renderer: Jinja2TemplateRenderer, - max_output_length: int | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._jinja2_template_renderer = jinja2_template_renderer - - if max_output_length is not None and max_output_length <= 0: - raise ValueError("max_output_length must be a positive integer") - self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. 
- :return: - """ - return { - "type": "template-transform", - "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - variables: dict[str, Any] = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - variables[variable_name] = value.to_object() if value else None - # Run code - try: - rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables) - except TemplateRenderError as e: - return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - - if len(rendered) > self._max_output_length: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {self._max_output_length} characters", - ) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered} - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: TemplateTransformNodeData | Mapping[str, Any], - ) -> Mapping[str, Sequence[str]]: - _ = graph_config - raw_variables = ( - node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", []) - ) - variable_mapping: dict[str, Sequence[str]] = {} - for variable_selector in raw_variables: - if isinstance(variable_selector, VariableSelector): - variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - continue - - if not isinstance(variable_selector, Mapping): - continue - - variable = variable_selector.get("variable") - value_selector = variable_selector.get("value_selector") - if ( - isinstance(variable, str) - and isinstance(value_selector, Sequence) - and all(isinstance(selector_part, str) for selector_part in value_selector) - ): - variable_mapping[node_id + "." + variable] = list(value_selector) - - return variable_mapping diff --git a/api/graphon/nodes/tool/__init__.py b/api/graphon/nodes/tool/__init__.py deleted file mode 100644 index f4982e655d..0000000000 --- a/api/graphon/nodes/tool/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tool_node import ToolNode - -__all__ = ["ToolNode"] diff --git a/api/graphon/nodes/tool/entities.py b/api/graphon/nodes/tool/entities.py deleted file mode 100644 index 54e6048033..0000000000 --- a/api/graphon/nodes/tool/entities.py +++ /dev/null @@ -1,101 +0,0 @@ -from enum import StrEnum, auto -from typing import Any, Literal, Union - -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class ToolProviderType(StrEnum): - """ - Graph-owned enum for persisted tool provider kinds. 
- """ - - PLUGIN = auto() - BUILT_IN = "builtin" - WORKFLOW = auto() - API = auto() - APP = auto() - DATASET_RETRIEVAL = "dataset-retrieval" - MCP = auto() - - -class ToolEntity(BaseModel): - provider_id: str - provider_type: ToolProviderType - provider_name: str # redundancy - tool_name: str - tool_label: str # redundancy - tool_configurations: dict[str, Any] - credential_id: str | None = None - plugin_unique_identifier: str | None = None # redundancy - - @field_validator("tool_configurations", mode="before") - @classmethod - def validate_tool_configurations(cls, value, values: ValidationInfo): - if not isinstance(value, dict): - raise ValueError("tool_configurations must be a dictionary") - - for key in values.data.get("tool_configurations", {}): - value = values.data.get("tool_configurations", {}).get(key) - if not isinstance(value, str | int | float | bool): - raise ValueError(f"{key} must be a string") - - return value - - -class ToolNodeData(BaseNodeData, ToolEntity): - type: NodeType = BuiltinNodeTypes.TOOL - - class ToolInput(BaseModel): - # TODO: check this type - value: Union[Any, list[str]] - type: Literal["mixed", "variable", "constant"] - - @field_validator("type", mode="before") - @classmethod - def check_type(cls, value, validation_info: ValidationInfo): - typ = value - value = validation_info.data.get("value") - - if value is None: - return typ - - if typ == "mixed" and not isinstance(value, str): - raise ValueError("value must be a string") - elif typ == "variable": - if not isinstance(value, list): - raise ValueError("value must be a list") - for val in value: - if not isinstance(val, str): - raise ValueError("value must be a list of strings") - elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))): - raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}") - return typ - - tool_parameters: dict[str, ToolInput] - # The version of the tool parameter. - # If this value is None, it indicates this is a previous version - # and requires using the legacy parameter parsing rules. 
- tool_node_version: str | None = None - - @field_validator("tool_parameters", mode="before") - @classmethod - def filter_none_tool_inputs(cls, value): - if not isinstance(value, dict): - return value - - return { - key: tool_input - for key, tool_input in value.items() - if tool_input is not None and cls._has_valid_value(tool_input) - } - - @staticmethod - def _has_valid_value(tool_input): - """Check if the value is valid""" - if isinstance(tool_input, dict): - return tool_input.get("value") is not None - return getattr(tool_input, "value", None) is not None diff --git a/api/graphon/nodes/tool/exc.py b/api/graphon/nodes/tool/exc.py deleted file mode 100644 index 1a309e1084..0000000000 --- a/api/graphon/nodes/tool/exc.py +++ /dev/null @@ -1,28 +0,0 @@ -class ToolNodeError(ValueError): - """Base exception for tool node errors.""" - - pass - - -class ToolRuntimeResolutionError(ToolNodeError): - """Raised when the workflow layer cannot construct a tool runtime.""" - - pass - - -class ToolRuntimeInvocationError(ToolNodeError): - """Raised when the workflow layer fails while invoking a tool runtime.""" - - pass - - -class ToolParameterError(ToolNodeError): - """Exception raised for errors in tool parameters.""" - - pass - - -class ToolFileError(ToolNodeError): - """Exception raised for errors related to tool files.""" - - pass diff --git a/api/graphon/nodes/tool/tool_node.py b/api/graphon/nodes/tool/tool_node.py deleted file mode 100644 index 57ab8ce5d6..0000000000 --- a/api/graphon/nodes/tool/tool_node.py +++ /dev/null @@ -1,432 +0,0 @@ -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from graphon.nodes.protocols import ToolFileManagerProtocol -from graphon.nodes.runtime import ToolNodeRuntimeProtocol -from graphon.nodes.tool_runtime_entities import ( - ToolRuntimeHandle, - ToolRuntimeMessage, - ToolRuntimeParameter, -) -from graphon.variables.segments import ArrayFileSegment - -from .entities import ToolNodeData -from .exc import ( - ToolFileError, - ToolNodeError, - ToolParameterError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - -class ToolNode(Node[ToolNodeData]): - """ - Tool Node - """ - - node_type = BuiltinNodeTypes.TOOL - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - tool_file_manager_factory: ToolFileManagerProtocol, - runtime: ToolNodeRuntimeProtocol | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._tool_file_manager_factory = tool_file_manager_factory - if runtime is None: - raise ValueError("runtime is required") - self._runtime = runtime - - @classmethod - def version(cls) -> str: - return "1" - - def populate_start_event(self, event) -> None: - event.provider_id = self.node_data.provider_id 
-        event.provider_type = self.node_data.provider_type
-
-    def _run(self) -> Generator[NodeEventBase, None, None]:
-        """
-        Run the tool node
-        """
-        # fetch tool icon
-        tool_info = {
-            "provider_type": self.node_data.provider_type.value,
-            "provider_id": self.node_data.provider_id,
-            "plugin_unique_identifier": self.node_data.plugin_unique_identifier,
-        }
-
-        # get tool runtime
-        try:
-            # This check caused problems in the past. Logically, we should not
-            # branch on the node_data.version field, but the check is kept here
-            # for backward compatibility with historical data.
-            variable_pool: VariablePool | None = None
-            if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
-                variable_pool = self.graph_runtime_state.variable_pool
-            tool_runtime = self._runtime.get_runtime(
-                node_id=self._node_id,
-                node_data=self.node_data,
-                variable_pool=variable_pool,
-            )
-        except ToolNodeError as e:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    inputs={},
-                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
-                    error=f"Failed to get tool runtime: {str(e)}",
-                    error_type=type(e).__name__,
-                )
-            )
-            return
-
-        # get parameters
-        tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime)
-        parameters = self._generate_parameters(
-            tool_parameters=tool_parameters,
-            variable_pool=self.graph_runtime_state.variable_pool,
-            node_data=self.node_data,
-        )
-        parameters_for_log = self._generate_parameters(
-            tool_parameters=tool_parameters,
-            variable_pool=self.graph_runtime_state.variable_pool,
-            node_data=self.node_data,
-            for_log=True,
-        )
-        try:
-            message_stream = self._runtime.invoke(
-                tool_runtime=tool_runtime,
-                tool_parameters=parameters,
-                workflow_call_depth=self.workflow_call_depth,
-                provider_name=self.node_data.provider_name,
-            )
-        except ToolNodeError as e:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    inputs=parameters_for_log,
-                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
-                    error=f"Failed to invoke tool: {str(e)}",
-                    error_type=type(e).__name__,
-                )
-            )
-            return
-
-        try:
-            # convert tool messages
-            _ = yield from self._transform_message(
-                messages=message_stream,
-                tool_info=tool_info,
-                parameters_for_log=parameters_for_log,
-                node_id=self._node_id,
-                tool_runtime=tool_runtime,
-            )
-        except ToolNodeError as e:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    inputs=parameters_for_log,
-                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
-                    error=str(e),
-                    error_type=type(e).__name__,
-                )
-            )
-
-    def _generate_parameters(
-        self,
-        *,
-        tool_parameters: Sequence[ToolRuntimeParameter],
-        variable_pool: "VariablePool",
-        node_data: ToolNodeData,
-        for_log: bool = False,
-    ) -> dict[str, Any]:
-        """
-        Generate parameters based on the given tool parameters, variable pool, and node data.
-
-        Args:
-            tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters.
-            variable_pool (VariablePool): The variable pool containing the variables.
-            node_data (ToolNodeData): The data associated with the tool node.
-            for_log (bool): Render template values for logging instead of execution.
-
-        Returns:
-            dict[str, Any]: A dictionary containing the generated parameters.
-
-        """
-        tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
-
-        result: dict[str, Any] = {}
-        for parameter_name in node_data.tool_parameters:
-            parameter = tool_parameters_dictionary.get(parameter_name)
-            if not parameter:
-                result[parameter_name] = None
-                continue
-            tool_input = node_data.tool_parameters[parameter_name]
-            if tool_input.type == "variable":
-                variable = variable_pool.get(tool_input.value)
-                if variable is None:
-                    if parameter.required:
-                        raise ToolParameterError(f"Variable {tool_input.value} does not exist")
-                    continue
-                parameter_value = variable.value
-            elif tool_input.type in {"mixed", "constant"}:
-                segment_group = variable_pool.convert_template(str(tool_input.value))
-                parameter_value = segment_group.log if for_log else segment_group.text
-            else:
-                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
-            result[parameter_name] = parameter_value
-
-        return result
-
-    def _transform_message(
-        self,
-        messages: Generator[ToolRuntimeMessage, None, None],
-        tool_info: Mapping[str, Any],
-        parameters_for_log: dict[str, Any],
-        node_id: str,
-        tool_runtime: ToolRuntimeHandle,
-        **_: Any,
-    ) -> Generator[NodeEventBase, None, LLMUsage]:
-        """
-        Convert graph-owned tool runtime messages into node outputs.
-        """
-        text = ""
-        files: list[File] = []
-        json: list[dict | list] = []
-
-        variables: dict[str, Any] = {}
-
-        for message in messages:
-            if message.type in {
-                ToolRuntimeMessage.MessageType.IMAGE_LINK,
-                ToolRuntimeMessage.MessageType.BINARY_LINK,
-                ToolRuntimeMessage.MessageType.IMAGE,
-            }:
-                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
-
-                url = message.message.text
-                if message.meta:
-                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
-                    tool_file_id = message.meta.get("tool_file_id")
-                else:
-                    transfer_method = FileTransferMethod.TOOL_FILE
-                    tool_file_id = None
-                if not isinstance(tool_file_id, str) or not tool_file_id:
-                    raise ToolFileError("tool message is missing tool_file_id metadata")
-
-                _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
-                if not tool_file:
-                    raise ToolFileError(f"tool file {tool_file_id} not found")
-                if tool_file.mime_type is None:
-                    raise ToolFileError(f"tool file {tool_file_id} is missing mime type")
-
-                file_mapping: dict[str, Any] = {
-                    "tool_file_id": tool_file_id,
-                    "type": get_file_type_by_mime_type(tool_file.mime_type),
-                    "transfer_method": transfer_method,
-                    "url": url,
-                }
-                file = self._runtime.build_file_reference(mapping=file_mapping)
-                files.append(file)
-            elif message.type == ToolRuntimeMessage.MessageType.BLOB:
-                # get tool file id
-                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
-                assert message.meta
-
-                tool_file_id = message.meta.get("tool_file_id")
-                if not isinstance(tool_file_id, str) or not tool_file_id:
-                    raise ToolFileError("tool blob message is missing tool_file_id metadata")
-                _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
-                if not tool_file:
-                    raise ToolFileError(f"tool file {tool_file_id} does not exist")
-
-                blob_file_mapping: dict[str, Any] = {
-                    "tool_file_id": tool_file_id,
-                    "transfer_method": FileTransferMethod.TOOL_FILE,
-                }
-
-                files.append(self._runtime.build_file_reference(mapping=blob_file_mapping))
-            elif message.type == ToolRuntimeMessage.MessageType.TEXT:
-                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
-                text += message.message.text
-                yield StreamChunkEvent(
-                    selector=[node_id, "text"],
-                    chunk=message.message.text,
-                    is_final=False,
-                )
-            elif message.type == ToolRuntimeMessage.MessageType.JSON:
-                assert isinstance(message.message, ToolRuntimeMessage.JsonMessage)
-                # JSON message handling for tool node
-                if message.message.json_object:
-                    json.append(message.message.json_object)
-            elif message.type == ToolRuntimeMessage.MessageType.LINK:
-                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
-
-                # Check if this LINK message is a file link
-                file_obj = (message.meta or {}).get("file")
-                if isinstance(file_obj, File):
-                    files.append(file_obj)
-                    stream_text = f"File: {message.message.text}\n"
-                else:
-                    stream_text = f"Link: {message.message.text}\n"
-
-                text += stream_text
-                yield StreamChunkEvent(
-                    selector=[node_id, "text"],
-                    chunk=stream_text,
-                    is_final=False,
-                )
-            elif message.type == ToolRuntimeMessage.MessageType.VARIABLE:
-                assert isinstance(message.message, ToolRuntimeMessage.VariableMessage)
-                variable_name = message.message.variable_name
-                variable_value = message.message.variable_value
-                if message.message.stream:
-                    if not isinstance(variable_value, str):
-                        raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
-                    if variable_name not in variables:
-                        variables[variable_name] = ""
-                    variables[variable_name] += variable_value
-
-                    yield StreamChunkEvent(
-                        selector=[node_id, variable_name],
-                        chunk=variable_value,
-                        is_final=False,
-                    )
-                else:
-                    variables[variable_name] = variable_value
-            elif message.type == ToolRuntimeMessage.MessageType.FILE:
-                assert message.meta is not None
-                assert isinstance(message.meta, dict)
-                # Validate that meta contains a 'file' key
-                if "file" not in message.meta:
-                    raise ToolNodeError("File message is missing 'file' key in meta")
-
-                # Validate that the file is an instance of File
-                if not isinstance(message.meta["file"], File):
-                    raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
-                files.append(message.meta["file"])
-            elif message.type == ToolRuntimeMessage.MessageType.LOG:
-                assert isinstance(message.message, ToolRuntimeMessage.LogMessage)
-                if message.message.metadata:
-                    icon = tool_info.get("icon", "")
-                    dict_metadata = dict(message.message.metadata)
-                    if dict_metadata.get("provider"):
-                        icon, icon_dark = self._runtime.resolve_provider_icons(
-                            provider_name=dict_metadata["provider"],
-                            default_icon=icon,
-                        )
-                        dict_metadata["icon"] = icon
-                        dict_metadata["icon_dark"] = icon_dark
-                    message.message.metadata = dict_metadata
-
-        # Expose collected JSON (agent logs included) via outputs['json'] so the
-        # frontend can access the tool's thinking process.
-        json_output: list[dict[str, Any] | list[Any]] = []
-
-        # Normalize the JSON output: pass collected objects through, or fall
-        # back to a single {"data": []} placeholder when nothing was collected.
-        if json:
-            json_output.extend(json)
-        else:
-            json_output.append({"data": []})
-
-        # Send final chunk events for all streamed outputs
-        # Final chunk for text stream
-        yield StreamChunkEvent(
-            selector=[self._node_id, "text"],
-            chunk="",
-            is_final=True,
-        )
-
-        # Final chunks for any streamed variables
-        for var_name in variables:
-            yield StreamChunkEvent(
-                selector=[self._node_id, var_name],
-                chunk="",
-                is_final=True,
-            )
-
-        usage = self._runtime.get_usage(tool_runtime=tool_runtime)
-
-        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
-            WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
-        }
-        if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
-            metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
-            
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price - metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, - metadata=metadata, - inputs=parameters_for_log, - llm_usage=usage, - ) - ) - - return usage - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ToolNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - _ = graph_config # Explicitly mark as unused - typed_node_data = node_data - result = {} - for parameter_name in typed_node_data.tool_parameters: - input = typed_node_data.tool_parameters[parameter_name] - match input.type: - case "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - case "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - case "constant": - pass - - result = {node_id + "." + key: value for key, value in result.items()} - - return result - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/graphon/nodes/tool_runtime_entities.py b/api/graphon/nodes/tool_runtime_entities.py deleted file mode 100644 index 5bb0c16573..0000000000 --- a/api/graphon/nodes/tool_runtime_entities.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - - -class _ToolRuntimeModel(BaseModel): - model_config = ConfigDict(extra="forbid") - - -@dataclass(frozen=True, slots=True) -class ToolRuntimeHandle: - """Opaque graph-owned handle for a workflow-layer tool runtime. - - Workflow-specific execution context must stay behind `raw` so the graph - contract does not absorb application-owned concepts. 
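-
-    For example, the application layer can stash its own runtime object in
-    `raw`; graph code only passes the handle along and never inspects it.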
- """ - - raw: object - - -@dataclass(frozen=True, slots=True) -class ToolRuntimeParameter: - """Graph-owned parameter shape used by tool nodes.""" - - name: str - required: bool = False - - -class ToolRuntimeMessage(_ToolRuntimeModel): - """Graph-owned tool invocation message DTO.""" - - class TextMessage(_ToolRuntimeModel): - text: str - - class JsonMessage(_ToolRuntimeModel): - json_object: dict[str, Any] | list[Any] - suppress_output: bool = Field(default=False) - - class BlobMessage(_ToolRuntimeModel): - blob: bytes - - class BlobChunkMessage(_ToolRuntimeModel): - id: str - sequence: int - total_length: int - blob: bytes - end: bool - - class FileMessage(_ToolRuntimeModel): - file_marker: str = Field(default="file_marker") - - class VariableMessage(_ToolRuntimeModel): - variable_name: str - variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None - stream: bool = Field(default=False) - - class LogMessage(_ToolRuntimeModel): - class LogStatus(StrEnum): - START = auto() - ERROR = auto() - SUCCESS = auto() - - id: str - label: str - parent_id: str | None = None - error: str | None = None - status: LogStatus - data: dict[str, Any] - metadata: dict[str, Any] = Field(default_factory=dict) - - class RetrieverResourceMessage(_ToolRuntimeModel): - retriever_resources: list[dict[str, Any]] - context: str - - class MessageType(StrEnum): - TEXT = auto() - IMAGE = auto() - LINK = auto() - BLOB = auto() - JSON = auto() - IMAGE_LINK = auto() - BINARY_LINK = auto() - VARIABLE = auto() - FILE = auto() - LOG = auto() - BLOB_CHUNK = auto() - RETRIEVER_RESOURCES = auto() - - type: MessageType = MessageType.TEXT - message: ( - JsonMessage - | TextMessage - | BlobChunkMessage - | BlobMessage - | LogMessage - | FileMessage - | None - | VariableMessage - | RetrieverResourceMessage - ) - meta: dict[str, Any] | None = None diff --git a/api/graphon/nodes/variable_aggregator/__init__.py b/api/graphon/nodes/variable_aggregator/__init__.py deleted file mode 100644 index 0b6bf2a5b6..0000000000 --- a/api/graphon/nodes/variable_aggregator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .variable_aggregator_node import VariableAggregatorNode - -__all__ = ["VariableAggregatorNode"] diff --git a/api/graphon/nodes/variable_aggregator/entities.py b/api/graphon/nodes/variable_aggregator/entities.py deleted file mode 100644 index 136fd28f8c..0000000000 --- a/api/graphon/nodes/variable_aggregator/entities.py +++ /dev/null @@ -1,35 +0,0 @@ -from pydantic import BaseModel - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.variables.types import SegmentType - - -class AdvancedSettings(BaseModel): - """ - Advanced setting. - """ - - group_enabled: bool - - class Group(BaseModel): - """ - Group. - """ - - output_type: SegmentType - variables: list[list[str]] - group_name: str - - groups: list[Group] - - -class VariableAggregatorNodeData(BaseNodeData): - """ - Variable Aggregator Node Data. 
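-
-    Picks the first non-null variable among `variables`, or one per group
-    when `advanced_settings.group_enabled` is set.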
- """ - - type: NodeType = BuiltinNodeTypes.VARIABLE_AGGREGATOR - output_type: str - variables: list[list[str]] - advanced_settings: AdvancedSettings | None = None diff --git a/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py b/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py deleted file mode 100644 index 71b221e196..0000000000 --- a/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from graphon.variables.segments import Segment - - -class VariableAggregatorNode(Node[VariableAggregatorNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_AGGREGATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - outputs: dict[str, Segment | Mapping[str, Segment]] = {} - inputs = {} - - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - if variable is not None: - outputs = {"output": variable} - - inputs = {".".join(selector[1:]): variable.to_object()} - break - else: - for group in self.node_data.advanced_settings.groups: - for selector in group.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - - if variable is not None: - outputs[group.group_name] = {"output": variable} - inputs[".".join(selector[1:])] = variable.to_object() - break - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) diff --git a/api/graphon/nodes/variable_assigner/__init__.py b/api/graphon/nodes/variable_assigner/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/nodes/variable_assigner/common/__init__.py b/api/graphon/nodes/variable_assigner/common/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/nodes/variable_assigner/common/exc.py b/api/graphon/nodes/variable_assigner/common/exc.py deleted file mode 100644 index f8dbedc290..0000000000 --- a/api/graphon/nodes/variable_assigner/common/exc.py +++ /dev/null @@ -1,4 +0,0 @@ -class VariableOperatorNodeError(ValueError): - """Base error type, don't use directly.""" - - pass diff --git a/api/graphon/nodes/variable_assigner/common/helpers.py b/api/graphon/nodes/variable_assigner/common/helpers.py deleted file mode 100644 index 4c30e009f2..0000000000 --- a/api/graphon/nodes/variable_assigner/common/helpers.py +++ /dev/null @@ -1,55 +0,0 @@ -from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, TypeVar - -from pydantic import BaseModel - -from graphon.variables import Segment -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.types import SegmentType - -# Use double underscore (`__`) prefix for internal variables -# to minimize risk of collision with user-defined variable names. 
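-# Callers should treat the literal key as an implementation detail and go
-# through set_updated_variables() / get_updated_variables() below instead.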
-_UPDATED_VARIABLES_KEY = "__updated_variables" - - -class UpdatedVariable(BaseModel): - name: str - selector: Sequence[str] - value_type: SegmentType - new_value: Any = None - - -_T = TypeVar("_T", bound=MutableMapping[str, Any]) - - -def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: - if len(selector) < SELECTORS_LENGTH: - raise Exception("selector too short") - _, var_name = selector[:2] - return UpdatedVariable( - name=var_name, - selector=list(selector[:2]), - value_type=seg.value_type, - new_value=seg.value, - ) - - -def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: - m[_UPDATED_VARIABLES_KEY] = updates - return m - - -def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: - updated_values = m.get(_UPDATED_VARIABLES_KEY, None) - if updated_values is None: - return None - result = [] - for items in updated_values: - if isinstance(items, UpdatedVariable): - result.append(items) - elif isinstance(items, dict): - items = UpdatedVariable.model_validate(items) - result.append(items) - else: - raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") - return result diff --git a/api/graphon/nodes/variable_assigner/v1/__init__.py b/api/graphon/nodes/variable_assigner/v1/__init__.py deleted file mode 100644 index 7eb1428e50..0000000000 --- a/api/graphon/nodes/variable_assigner/v1/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/graphon/nodes/variable_assigner/v1/node.py b/api/graphon/nodes/variable_assigner/v1/node.py deleted file mode 100644 index 19ded5f123..0000000000 --- a/api/graphon/nodes/variable_assigner/v1/node.py +++ /dev/null @@ -1,106 +0,0 @@ -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from graphon.variables import SegmentType, Variable, VariableBase - -from .node_data import VariableAssignerData, WriteMode - -if TYPE_CHECKING: - from graphon.runtime import GraphRuntimeState - - -class VariableAssignerNode(Node[VariableAssignerData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. 
- """ - assigned_selector = tuple(self.node_data.assigned_variable_selector) - return assigned_selector in variable_selectors - - @classmethod - def version(cls) -> str: - return "1" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerData, - ) -> Mapping[str, Sequence[str]]: - mapping = {} - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector - - selector_key = ".".join(node_data.input_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector - return mapping - - def _run(self) -> Generator[NodeEventBase, None, None]: - assigned_variable_selector = self.node_data.assigned_variable_selector - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, VariableBase): - raise VariableOperatorNodeError("assigned variable not found") - - match self.node_data.write_mode: - case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_variable = original_variable.model_copy(update={"value": income_value.value}) - - case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={"value": updated_value}) - - case WriteMode.CLEAR: - income_value = SegmentType.get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - yield VariableUpdatedEvent(variable=cast(Variable, updated_variable)) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. 
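-                # Downstream consumers can read them back uniformly via
-                # common_helpers.get_updated_variables(process_data).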
- process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, - ) - ) diff --git a/api/graphon/nodes/variable_assigner/v1/node_data.py b/api/graphon/nodes/variable_assigner/v1/node_data.py deleted file mode 100644 index 4f630bc76c..0000000000 --- a/api/graphon/nodes/variable_assigner/v1/node_data.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - - -class WriteMode(StrEnum): - OVER_WRITE = "over-write" - APPEND = "append" - CLEAR = "clear" - - -class VariableAssignerData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] diff --git a/api/graphon/nodes/variable_assigner/v2/__init__.py b/api/graphon/nodes/variable_assigner/v2/__init__.py deleted file mode 100644 index 7eb1428e50..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/graphon/nodes/variable_assigner/v2/entities.py b/api/graphon/nodes/variable_assigner/v2/entities.py deleted file mode 100644 index d1c68c8e8c..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/entities.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from pydantic import BaseModel, Field - -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType - -from .enums import InputType, Operation - - -class VariableOperationItem(BaseModel): - variable_selector: Sequence[str] - input_type: InputType - operation: Operation - # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: - # - # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. - # 2. For VARIABLE input_type: Initially contains the selector of the source variable. - # 3. During the variable updating procedure: The `value` field is reassigned to hold - # the resolved actual value that will be applied to the target variable. 
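-    # In practice: treat `value` as untyped input until the node validates
-    # and resolves it.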
- value: Any = None - - -class VariableAssignerNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - version: str = "2" - items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/graphon/nodes/variable_assigner/v2/enums.py b/api/graphon/nodes/variable_assigner/v2/enums.py deleted file mode 100644 index 291b1208d4..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/enums.py +++ /dev/null @@ -1,20 +0,0 @@ -from enum import StrEnum - - -class Operation(StrEnum): - OVER_WRITE = "over-write" - CLEAR = "clear" - APPEND = "append" - EXTEND = "extend" - SET = "set" - ADD = "+=" - SUBTRACT = "-=" - MULTIPLY = "*=" - DIVIDE = "/=" - REMOVE_FIRST = "remove-first" - REMOVE_LAST = "remove-last" - - -class InputType(StrEnum): - VARIABLE = "variable" - CONSTANT = "constant" diff --git a/api/graphon/nodes/variable_assigner/v2/exc.py b/api/graphon/nodes/variable_assigner/v2/exc.py deleted file mode 100644 index 90d7648574..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/exc.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError - -from .enums import InputType, Operation - - -class OperationNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, operation: Operation, variable_type: str): - super().__init__(f"Operation {operation} is not supported for type {variable_type}") - - -class InputTypeNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, input_type: InputType, operation: Operation): - super().__init__(f"Input type {input_type} is not supported for operation {operation}") - - -class VariableNotFoundError(VariableOperatorNodeError): - def __init__(self, *, variable_selector: Sequence[str]): - super().__init__(f"Variable {variable_selector} not found") - - -class InvalidInputValueError(VariableOperatorNodeError): - def __init__(self, *, value: Any): - super().__init__(f"Invalid input value {value}") - - -class ConversationIDNotFoundError(VariableOperatorNodeError): - def __init__(self): - super().__init__("conversation_id not found") - - -class InvalidDataError(VariableOperatorNodeError): - def __init__(self, message: str): - super().__init__(message) diff --git a/api/graphon/nodes/variable_assigner/v2/helpers.py b/api/graphon/nodes/variable_assigner/v2/helpers.py deleted file mode 100644 index ebc6c79476..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/helpers.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import Any - -from graphon.variables import SegmentType - -from .enums import Operation - - -def is_operation_supported(*, variable_type: SegmentType, operation: Operation): - match operation: - case Operation.OVER_WRITE | Operation.CLEAR: - return True - case Operation.SET: - return variable_type in { - SegmentType.OBJECT, - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - SegmentType.BOOLEAN, - } - case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: - # Only number variable can be added, subtracted, multiplied or divided - return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} - case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST: - # Only array variable can be appended or extended - # Only array variable can have elements removed - return variable_type.is_array_type() - - -def is_variable_input_supported(*, operation: Operation): - if 
operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}: - return False - return True - - -def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation): - match variable_type: - case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN: - return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return operation in { - Operation.OVER_WRITE, - Operation.SET, - Operation.ADD, - Operation.SUBTRACT, - Operation.MULTIPLY, - Operation.DIVIDE, - } - case _: - return False - - -def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any): - if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}: - return True - match variable_type: - case SegmentType.STRING: - return isinstance(value, str) - - case SegmentType.BOOLEAN: - return isinstance(value, bool) - - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - if not isinstance(value, int | float): - return False - if operation == Operation.DIVIDE and value == 0: - return False - return True - - case SegmentType.OBJECT: - return isinstance(value, dict) - - # Array & Append - case SegmentType.ARRAY_ANY if operation == Operation.APPEND: - return isinstance(value, str | float | int | dict) - case SegmentType.ARRAY_STRING if operation == Operation.APPEND: - return isinstance(value, str) - case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND: - return isinstance(value, int | float) - case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: - return isinstance(value, dict) - case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND: - return isinstance(value, bool) - - # Array & Extend / Overwrite - case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value) - case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str) for item in value) - case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, int | float) for item in value) - case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, dict) for item in value) - case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, bool) for item in value) - - case _: - return False diff --git a/api/graphon/nodes/variable_assigner/v2/node.py b/api/graphon/nodes/variable_assigner/v2/node.py deleted file mode 100644 index 887bd1b604..0000000000 --- a/api/graphon/nodes/variable_assigner/v2/node.py +++ /dev/null @@ -1,257 +0,0 @@ -import json -from collections.abc import Generator, Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent -from graphon.nodes.base.node import Node -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from 
graphon.variables import SegmentType, Variable, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH - -from . import helpers -from .entities import VariableAssignerNodeData, VariableOperationItem -from .enums import InputType, Operation -from .exc import ( - InputTypeNotSupportedError, - InvalidDataError, - InvalidInputValueError, - OperationNotSupportedError, - VariableNotFoundError, -) - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState - - -def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_str = ".".join(item.variable_selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = item.variable_selector - - -def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - # Keep this in sync with the logic in _run methods... - if item.input_type != InputType.VARIABLE: - return - selector = item.value - if not isinstance(selector, list): - raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") - if len(selector) < SELECTORS_LENGTH: - raise InvalidDataError(f"selector too short, {node_id=}, {item=}") - selector_str = ".".join(selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = selector - - -class VariableAssignerNode(Node[VariableAssignerNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. - """ - # Check each item in this Variable Assigner node - for item in self.node_data.items: - # Convert the item's variable_selector to tuple for comparison - item_selector_tuple = tuple(item.variable_selector) - - # Check if this item updates any of the requested variables - if item_selector_tuple in variable_selectors: - return True - - return False - - @classmethod - def version(cls) -> str: - return "2" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: - _target_mapping_from_item(var_mapping, node_id, item) - _source_mapping_from_item(var_mapping, node_id, item) - return var_mapping - - def _run(self) -> Generator[NodeEventBase, None, None]: - inputs = self.node_data.model_dump() - process_data: dict[str, Any] = {} - # NOTE: This node has no outputs - updated_variable_selectors: list[Sequence[str]] = [] - # Preserve intra-node read-after-write behavior without mutating the shared pool - # until the engine processes the emitted VariableUpdatedEvent instances. 
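-        # Later items in the loop observe earlier writes through this copy,
-        # while the rest of the workflow keeps seeing the unmodified pool.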
-        working_variable_pool = self.graph_runtime_state.variable_pool.model_copy(deep=True)
-
-        try:
-            for item in self.node_data.items:
-                variable = working_variable_pool.get(item.variable_selector)
-
-                # ==================== Validation Part
-
-                # Check if variable exists
-                if not isinstance(variable, VariableBase):
-                    raise VariableNotFoundError(variable_selector=item.variable_selector)
-
-                # Check if operation is supported
-                if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
-                    raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type)
-
-                # Check if variable input is supported
-                if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
-                    operation=item.operation
-                ):
-                    raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
-
-                # Check if constant input is supported
-                if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
-                    variable_type=variable.value_type, operation=item.operation
-                ):
-                    raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
-
-                # Get value from variable pool
-                input_value = item.value
-                if (
-                    item.input_type == InputType.VARIABLE
-                    and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}
-                    and item.value is not None
-                ):
-                    value = working_variable_pool.get(item.value)
-                    if value is None:
-                        raise VariableNotFoundError(variable_selector=item.value)
-                    # Skip if value is NoneSegment
-                    if value.value_type == SegmentType.NONE:
-                        continue
-                    input_value = value.value
-
-                # When setting a string / bytes / bytearray value into an object
-                # variable, try to parse the input as JSON.
-                if (
-                    item.operation == Operation.SET
-                    and variable.value_type == SegmentType.OBJECT
-                    and isinstance(input_value, str | bytes | bytearray)
-                ):
-                    try:
-                        input_value = json.loads(input_value)
-                    except json.JSONDecodeError:
-                        raise InvalidInputValueError(value=input_value)
-
-                # Check if input value is valid
-                if not helpers.is_input_value_valid(
-                    variable_type=variable.value_type, operation=item.operation, value=input_value
-                ):
-                    raise InvalidInputValueError(value=input_value)
-
-                # ==================== Execution Part
-
-                updated_value = self._handle_item(
-                    variable=variable,
-                    operation=item.operation,
-                    value=input_value,
-                )
-                updated_variable = variable.model_copy(update={"value": updated_value})
-                working_variable_pool.add(updated_variable.selector, updated_variable)
-                updated_variable_selectors.append(updated_variable.selector)
-        except VariableOperatorNodeError as e:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    inputs=inputs,
-                    process_data=process_data,
-                    error=str(e),
-                )
-            )
-            return
-
-        # `updated_variable_selectors` contains list[str] items, which are not
-        # hashable; convert them to tuples to drop duplicates while preserving
-        # the first-update order.
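-        # dict.fromkeys keeps first-seen order, so converting each selector to
-        # a hashable tuple yields an order-stable de-duplication.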
- updated_variable_selectors = list(dict.fromkeys(map(tuple, updated_variable_selectors))) - - for selector in updated_variable_selectors: - variable = working_variable_pool.get(selector) - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=selector) - process_data[variable.name] = variable.value - - updated_variables = [ - common_helpers.variable_to_processed_data(selector, seg) - for selector in updated_variable_selectors - if (seg := working_variable_pool.get(selector)) is not None - ] - - process_data = common_helpers.set_updated_variables(process_data, updated_variables) - for selector in updated_variable_selectors: - variable = working_variable_pool.get(selector) - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=selector) - yield VariableUpdatedEvent(variable=cast(Variable, variable)) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, - ) - ) - - def _handle_item( - self, - *, - variable: VariableBase, - operation: Operation, - value: Any, - ): - match operation: - case Operation.OVER_WRITE: - return value - case Operation.CLEAR: - return SegmentType.get_zero_value(variable.value_type).to_object() - case Operation.APPEND: - return variable.value + [value] - case Operation.EXTEND: - return variable.value + value - case Operation.SET: - return value - case Operation.ADD: - return variable.value + value - case Operation.SUBTRACT: - return variable.value - value - case Operation.MULTIPLY: - return variable.value * value - case Operation.DIVIDE: - return variable.value / value - case Operation.REMOVE_FIRST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[:-1] diff --git a/api/graphon/prompt_entities.py b/api/graphon/prompt_entities.py deleted file mode 100644 index 2b8b106c6c..0000000000 --- a/api/graphon/prompt_entities.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -from graphon.model_runtime.entities.message_entities import PromptMessageRole - - -class ChatModelMessage(BaseModel): - """Graph-owned chat prompt template message.""" - - text: str - role: PromptMessageRole - edition_type: Literal["basic", "jinja2"] | None = None - - -class CompletionModelPromptTemplate(BaseModel): - """Graph-owned completion prompt template.""" - - text: str - edition_type: Literal["basic", "jinja2"] | None = None - - -class MemoryConfig(BaseModel): - """Graph-owned memory configuration for prompt assembly.""" - - class RolePrefix(BaseModel): - """Role labels used when serializing completion-model histories.""" - - user: str - assistant: str - - class WindowConfig(BaseModel): - """History windowing controls.""" - - enabled: bool - size: int | None = None - - role_prefix: RolePrefix | None = None - window: WindowConfig - query_prompt_template: str | None = None - - -__all__ = [ - "ChatModelMessage", - "CompletionModelPromptTemplate", - "MemoryConfig", -] diff --git a/api/graphon/runtime/__init__.py b/api/graphon/runtime/__init__.py deleted file mode 100644 index adca07e59a..0000000000 --- a/api/graphon/runtime/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .graph_runtime_state import ( - ChildEngineBuilderNotConfiguredError, - ChildEngineError, - 
ChildGraphNotFoundError, - GraphRuntimeState, -) -from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool -from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper -from .variable_pool import VariablePool, VariableValue - -__all__ = [ - "ChildEngineBuilderNotConfiguredError", - "ChildEngineError", - "ChildGraphNotFoundError", - "GraphRuntimeState", - "ReadOnlyGraphRuntimeState", - "ReadOnlyGraphRuntimeStateWrapper", - "ReadOnlyVariablePool", - "ReadOnlyVariablePoolWrapper", - "VariablePool", - "VariableValue", -] diff --git a/api/graphon/runtime/graph_runtime_state.py b/api/graphon/runtime/graph_runtime_state.py deleted file mode 100644 index 6e4ed202b5..0000000000 --- a/api/graphon/runtime/graph_runtime_state.py +++ /dev/null @@ -1,693 +0,0 @@ -from __future__ import annotations - -import importlib -import json -from collections.abc import Mapping, Sequence -from contextlib import AbstractContextManager, nullcontext -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Protocol - -from pydantic import BaseModel, Field -from pydantic.json import pydantic_encoder - -from graphon.enums import NodeExecutionType, NodeState, NodeType -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime.variable_pool import VariablePool - -if TYPE_CHECKING: - from graphon.entities import GraphInitParams - from graphon.entities.pause_reason import PauseReason - - -class ReadyQueueProtocol(Protocol): - """Structural interface required from ready queue implementations.""" - - def put(self, item: str) -> None: - """Enqueue the identifier of a node that is ready to run.""" - ... - - def get(self, timeout: float | None = None) -> str: - """Return the next node identifier, blocking until available or timeout expires.""" - ... - - def task_done(self) -> None: - """Signal that the most recently dequeued node has completed processing.""" - ... - - def empty(self) -> bool: - """Return True when the queue contains no pending nodes.""" - ... - - def qsize(self) -> int: - """Approximate the number of pending nodes awaiting execution.""" - ... - - def dumps(self) -> str: - """Serialize the queue contents for persistence.""" - ... - - def loads(self, data: str) -> None: - """Restore the queue contents from a serialized payload.""" - ... - - -class GraphExecutionProtocol(Protocol): - """Structural interface for graph execution aggregate. - - Defines the minimal set of attributes and methods required from a GraphExecution entity - for runtime orchestration and state management. - """ - - workflow_id: str - started: bool - completed: bool - aborted: bool - error: Exception | None - exceptions_count: int - pause_reasons: list[PauseReason] - - def start(self) -> None: - """Transition execution into the running state.""" - ... - - def complete(self) -> None: - """Mark execution as successfully completed.""" - ... - - def abort(self, reason: str) -> None: - """Abort execution in response to an external stop request.""" - ... - - def fail(self, error: Exception) -> None: - """Record an unrecoverable error and end execution.""" - ... - - def dumps(self) -> str: - """Serialize execution state into a JSON payload.""" - ... - - def loads(self, data: str) -> None: - """Restore execution state from a previously serialized payload.""" - ... 
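-
-
-# A minimal sketch, not part of the original module: an in-memory stub that
-# satisfies ReadyQueueProtocol above, e.g. for unit tests. The class name and
-# internals are illustrative assumptions, not the engine's real ready queue.
-import json as _json
-from collections import deque as _deque
-
-
-class _StubReadyQueue:
-    def __init__(self) -> None:
-        self._items: _deque[str] = _deque()
-
-    def put(self, item: str) -> None:
-        self._items.append(item)
-
-    def get(self, timeout: float | None = None) -> str:
-        # Non-blocking for simplicity; a real implementation would honor `timeout`.
-        return self._items.popleft()
-
-    def task_done(self) -> None:
-        # Nothing to track in this stub.
-        pass
-
-    def empty(self) -> bool:
-        return not self._items
-
-    def qsize(self) -> int:
-        return len(self._items)
-
-    def dumps(self) -> str:
-        # Persist the pending node ids as a JSON array.
-        return _json.dumps(list(self._items))
-
-    def loads(self, data: str) -> None:
-        self._items = _deque(_json.loads(data))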
- - -class ResponseStreamCoordinatorProtocol(Protocol): - """Structural interface for response stream coordinator.""" - - def register(self, response_node_id: str) -> None: - """Register a response node so its outputs can be streamed.""" - ... - - def loads(self, data: str) -> None: - """Restore coordinator state from a serialized payload.""" - ... - - def dumps(self) -> str: - """Serialize coordinator state for persistence.""" - ... - - -class NodeProtocol(Protocol): - """Structural interface for graph nodes.""" - - id: str - state: NodeState - execution_type: NodeExecutionType - node_type: ClassVar[NodeType] - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... - - -class EdgeProtocol(Protocol): - id: str - state: NodeState - tail: str - head: str - source_handle: str - - -class GraphProtocol(Protocol): - """Structural interface required from graph instances attached to the runtime state.""" - - nodes: Mapping[str, NodeProtocol] - edges: Mapping[str, EdgeProtocol] - root_node: NodeProtocol - - def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... - - -class ChildGraphEngineBuilderProtocol(Protocol): - def build_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - parent_graph_runtime_state: GraphRuntimeState, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> Any: ... - - -class ChildEngineError(ValueError): - """Base error type for child-engine creation failures.""" - - -class ChildEngineBuilderNotConfiguredError(ChildEngineError): - """Raised when child-engine creation is requested without a bound builder.""" - - -class ChildGraphNotFoundError(ChildEngineError): - """Raised when the requested child graph entry point cannot be resolved.""" - - -class _GraphStateSnapshot(BaseModel): - """Serializable graph state snapshot for node/edge states.""" - - nodes: dict[str, NodeState] = Field(default_factory=dict) - edges: dict[str, NodeState] = Field(default_factory=dict) - - -@dataclass(slots=True) -class _GraphRuntimeStateSnapshot: - """Immutable view of a serialized runtime state snapshot.""" - - start_at: float - total_tokens: int - node_run_steps: int - llm_usage: LLMUsage - outputs: dict[str, Any] - variable_pool: VariablePool - has_variable_pool: bool - ready_queue_dump: str | None - graph_execution_dump: str | None - response_coordinator_dump: str | None - paused_nodes: tuple[str, ...] - deferred_nodes: tuple[str, ...] - graph_node_states: dict[str, NodeState] - graph_edge_states: dict[str, NodeState] - - -class GraphRuntimeState: - """Mutable runtime state shared across graph execution components. - - `GraphRuntimeState` encapsulates the runtime state of workflow execution, - including scheduling details, variable values, and timing information. - - Values that are initialized prior to workflow execution and remain constant - throughout the execution should be part of `GraphInitParams` instead. 
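-
-    The full state can be serialized with `dumps()` and restored with
-    `from_snapshot()` (or the legacy `loads()`) to support pause and resume.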
- """ - - def __init__( - self, - *, - variable_pool: VariablePool, - start_at: float, - total_tokens: int = 0, - llm_usage: LLMUsage | None = None, - outputs: dict[str, object] | None = None, - node_run_steps: int = 0, - ready_queue: ReadyQueueProtocol | None = None, - graph_execution: GraphExecutionProtocol | None = None, - response_coordinator: ResponseStreamCoordinatorProtocol | None = None, - graph: GraphProtocol | None = None, - execution_context: AbstractContextManager[object] | None = None, - ) -> None: - self._variable_pool = variable_pool - self._start_at = start_at - - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = total_tokens - - self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy() - self._outputs = deepcopy(outputs) if outputs is not None else {} - - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = node_run_steps - - self._graph: GraphProtocol | None = None - - self._ready_queue = ready_queue - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - # Application code injects this when worker threads must restore request - # or framework-local state. It is intentionally excluded from snapshots. - self._execution_context = execution_context if execution_context is not None else nullcontext(None) - self._pending_response_coordinator_dump: str | None = None - self._pending_graph_execution_workflow_id: str | None = None - self._paused_nodes: set[str] = set() - self._deferred_nodes: set[str] = set() - self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None - - # Node and edges states needed to be restored into - # graph object. - # - # These two fields are non-None only when resuming from a snapshot. - # Once the graph is attached, these two fields will be set to None. 
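-        # _apply_pending_graph_state() performs the hand-off once a graph is
-        # attached.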
- self._pending_graph_node_states: dict[str, NodeState] | None = None - self._pending_graph_edge_states: dict[str, NodeState] | None = None - - if graph is not None: - self.attach_graph(graph) - - # ------------------------------------------------------------------ - # Context binding helpers - # ------------------------------------------------------------------ - def attach_graph(self, graph: GraphProtocol) -> None: - """Attach the materialized graph to the runtime state.""" - if self._graph is not None and self._graph is not graph: - raise ValueError("GraphRuntimeState already attached to a different graph instance") - - self._graph = graph - - if self._response_coordinator is None: - self._response_coordinator = self._build_response_coordinator(graph) - - if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None: - self._response_coordinator.loads(self._pending_response_coordinator_dump) - self._pending_response_coordinator_dump = None - self._apply_pending_graph_state() - - def configure(self, *, graph: GraphProtocol | None = None) -> None: - """Ensure core collaborators are initialized with the provided context.""" - if graph is not None: - self.attach_graph(graph) - - # Ensure collaborators are instantiated - _ = self.ready_queue - _ = self.graph_execution - if self._graph is not None: - _ = self.response_coordinator - - def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None: - self._child_engine_builder = builder - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - root_node_id: str, - variable_pool: VariablePool | None = None, - ) -> Any: - """Create a child graph engine that derives its runtime state from the parent.""" - if self._child_engine_builder is None: - raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") - - return self._child_engine_builder.build_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - parent_graph_runtime_state=self, - root_node_id=root_node_id, - variable_pool=variable_pool, - ) - - # ------------------------------------------------------------------ - # Primary collaborators - # ------------------------------------------------------------------ - @property - def variable_pool(self) -> VariablePool: - return self._variable_pool - - @property - def ready_queue(self) -> ReadyQueueProtocol: - if self._ready_queue is None: - self._ready_queue = self._build_ready_queue() - return self._ready_queue - - @property - def graph_execution(self) -> GraphExecutionProtocol: - if self._graph_execution is None: - self._graph_execution = self._build_graph_execution() - return self._graph_execution - - @property - def response_coordinator(self) -> ResponseStreamCoordinatorProtocol: - if self._response_coordinator is None: - if self._graph is None: - raise ValueError("Graph must be attached before accessing response coordinator") - self._response_coordinator = self._build_response_coordinator(self._graph) - return self._response_coordinator - - @property - def execution_context(self) -> AbstractContextManager[object]: - return self._execution_context - - @execution_context.setter - def execution_context(self, value: AbstractContextManager[object] | None) -> None: - self._execution_context = value if value is not None else nullcontext(None) - - # ------------------------------------------------------------------ - # Scalar state - # 
------------------------------------------------------------------ - @property - def start_at(self) -> float: - return self._start_at - - @start_at.setter - def start_at(self, value: float) -> None: - self._start_at = value - - @property - def total_tokens(self) -> int: - return self._total_tokens - - @total_tokens.setter - def total_tokens(self, value: int) -> None: - if value < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = value - - @property - def llm_usage(self) -> LLMUsage: - return self._llm_usage.model_copy() - - @llm_usage.setter - def llm_usage(self, value: LLMUsage) -> None: - self._llm_usage = value.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._outputs) - - @outputs.setter - def outputs(self, value: dict[str, Any]) -> None: - self._outputs = deepcopy(value) - - def set_output(self, key: str, value: object) -> None: - self._outputs[key] = deepcopy(value) - - def get_output(self, key: str, default: object = None) -> object: - return deepcopy(self._outputs.get(key, default)) - - def update_outputs(self, updates: dict[str, object]) -> None: - for key, value in updates.items(): - self._outputs[key] = deepcopy(value) - - @property - def node_run_steps(self) -> int: - return self._node_run_steps - - @node_run_steps.setter - def node_run_steps(self, value: int) -> None: - if value < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = value - - def increment_node_run_steps(self) -> None: - self._node_run_steps += 1 - - def add_tokens(self, tokens: int) -> None: - if tokens < 0: - raise ValueError("tokens must be non-negative") - self._total_tokens += tokens - - # ------------------------------------------------------------------ - # Serialization - # ------------------------------------------------------------------ - def dumps(self) -> str: - """Serialize runtime state into a JSON string.""" - - snapshot: dict[str, Any] = { - "version": "1.0", - "start_at": self._start_at, - "total_tokens": self._total_tokens, - "node_run_steps": self._node_run_steps, - "llm_usage": self._llm_usage.model_dump(mode="json"), - "outputs": self.outputs, - "variable_pool": self.variable_pool.model_dump(mode="json"), - "ready_queue": self.ready_queue.dumps(), - "graph_execution": self.graph_execution.dumps(), - "paused_nodes": list(self._paused_nodes), - "deferred_nodes": list(self._deferred_nodes), - } - - graph_state = self._snapshot_graph_state() - if graph_state is not None: - snapshot["graph_state"] = graph_state - - if self._response_coordinator is not None and self._graph is not None: - snapshot["response_coordinator"] = self._response_coordinator.dumps() - - return json.dumps(snapshot, default=pydantic_encoder) - - @classmethod - def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState: - """Restore runtime state from a serialized snapshot.""" - - snapshot = cls._parse_snapshot_payload(data) - - state = cls( - variable_pool=snapshot.variable_pool, - start_at=snapshot.start_at, - total_tokens=snapshot.total_tokens, - llm_usage=snapshot.llm_usage, - outputs=snapshot.outputs, - node_run_steps=snapshot.node_run_steps, - ) - state._apply_snapshot(snapshot) - return state - - def loads(self, data: str | Mapping[str, Any]) -> None: - """Restore runtime state from a serialized snapshot (legacy API).""" - - snapshot = self._parse_snapshot_payload(data) - self._apply_snapshot(snapshot) - - def register_paused_node(self, node_id: str) -> None: - """Record a node that should resume 
when execution is continued.""" - - self._paused_nodes.add(node_id) - - def get_paused_nodes(self) -> list[str]: - """Retrieve the list of paused nodes without mutating internal state.""" - - return list(self._paused_nodes) - - def consume_paused_nodes(self) -> list[str]: - """Retrieve and clear the list of paused nodes awaiting resume.""" - - nodes = list(self._paused_nodes) - self._paused_nodes.clear() - return nodes - - def register_deferred_node(self, node_id: str) -> None: - """Record a node that became ready during pause and should resume later.""" - - self._deferred_nodes.add(node_id) - - def get_deferred_nodes(self) -> list[str]: - """Retrieve deferred nodes without mutating internal state.""" - - return list(self._deferred_nodes) - - def consume_deferred_nodes(self) -> list[str]: - """Retrieve and clear deferred nodes awaiting resume.""" - - nodes = list(self._deferred_nodes) - self._deferred_nodes.clear() - return nodes - - # ------------------------------------------------------------------ - # Builders - # ------------------------------------------------------------------ - def _build_ready_queue(self) -> ReadyQueueProtocol: - # Import lazily to avoid breaching architecture boundaries enforced by import-linter. - module = importlib.import_module("graphon.graph_engine.ready_queue") - in_memory_cls = module.InMemoryReadyQueue - return in_memory_cls() - - def _build_graph_execution(self) -> GraphExecutionProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("graphon.graph_engine.domain.graph_execution") - graph_execution_cls = module.GraphExecution - workflow_id = self._pending_graph_execution_workflow_id or "" - self._pending_graph_execution_workflow_id = None - return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type] - - def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. 
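-        # Importing at call time keeps graph_engine out of this module's
-        # import-time dependency graph, per the import-linter contract.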
- module = importlib.import_module("graphon.graph_engine.response_coordinator") - coordinator_cls = module.ResponseStreamCoordinator - return coordinator_cls(variable_pool=self.variable_pool, graph=graph) - - # ------------------------------------------------------------------ - # Snapshot helpers - # ------------------------------------------------------------------ - @classmethod - def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot: - payload: dict[str, Any] - if isinstance(data, str): - payload = json.loads(data) - else: - payload = dict(data) - - version = payload.get("version") - if version != "1.0": - raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") - - start_at = float(payload.get("start_at", 0.0)) - - total_tokens = int(payload.get("total_tokens", 0)) - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - - node_run_steps = int(payload.get("node_run_steps", 0)) - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - - llm_usage_payload = payload.get("llm_usage", {}) - llm_usage = LLMUsage.model_validate(llm_usage_payload) - - outputs_payload = deepcopy(payload.get("outputs", {})) - - variable_pool_payload = payload.get("variable_pool") - has_variable_pool = variable_pool_payload is not None - variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool() - - ready_queue_payload = payload.get("ready_queue") - graph_execution_payload = payload.get("graph_execution") - response_payload = payload.get("response_coordinator") - paused_nodes_payload = payload.get("paused_nodes", []) - deferred_nodes_payload = payload.get("deferred_nodes", []) - graph_state_payload = payload.get("graph_state", {}) or {} - graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes") - graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges") - - return _GraphRuntimeStateSnapshot( - start_at=start_at, - total_tokens=total_tokens, - node_run_steps=node_run_steps, - llm_usage=llm_usage, - outputs=outputs_payload, - variable_pool=variable_pool, - has_variable_pool=has_variable_pool, - ready_queue_dump=ready_queue_payload, - graph_execution_dump=graph_execution_payload, - response_coordinator_dump=response_payload, - paused_nodes=tuple(map(str, paused_nodes_payload)), - deferred_nodes=tuple(map(str, deferred_nodes_payload)), - graph_node_states=graph_node_states, - graph_edge_states=graph_edge_states, - ) - - def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None: - self._start_at = snapshot.start_at - self._total_tokens = snapshot.total_tokens - self._node_run_steps = snapshot.node_run_steps - self._llm_usage = snapshot.llm_usage.model_copy() - self._outputs = deepcopy(snapshot.outputs) - if snapshot.has_variable_pool or self._variable_pool is None: - self._variable_pool = snapshot.variable_pool - - self._restore_ready_queue(snapshot.ready_queue_dump) - self._restore_graph_execution(snapshot.graph_execution_dump) - self._restore_response_coordinator(snapshot.response_coordinator_dump) - self._paused_nodes = set(snapshot.paused_nodes) - self._deferred_nodes = set(snapshot.deferred_nodes) - self._pending_graph_node_states = snapshot.graph_node_states or None - self._pending_graph_edge_states = snapshot.graph_edge_states or None - self._apply_pending_graph_state() - - def _restore_ready_queue(self, payload: str | None) -> None: - if payload is not None: - self._ready_queue = 
self._build_ready_queue() - self._ready_queue.loads(payload) - else: - self._ready_queue = None - - def _restore_graph_execution(self, payload: str | None) -> None: - self._graph_execution = None - self._pending_graph_execution_workflow_id = None - - if payload is None: - return - - try: - execution_payload = json.loads(payload) - self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id") - except (json.JSONDecodeError, TypeError, AttributeError): - self._pending_graph_execution_workflow_id = None - - self.graph_execution.loads(payload) - - def _restore_response_coordinator(self, payload: str | None) -> None: - if payload is None: - self._pending_response_coordinator_dump = None - self._response_coordinator = None - return - - if self._graph is not None: - self.response_coordinator.loads(payload) - self._pending_response_coordinator_dump = None - return - - self._pending_response_coordinator_dump = payload - self._response_coordinator = None - - def _snapshot_graph_state(self) -> _GraphStateSnapshot: - graph = self._graph - if graph is None: - if self._pending_graph_node_states is None and self._pending_graph_edge_states is None: - return _GraphStateSnapshot() - return _GraphStateSnapshot( - nodes=self._pending_graph_node_states or {}, - edges=self._pending_graph_edge_states or {}, - ) - - nodes = graph.nodes - edges = graph.edges - if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping): - return _GraphStateSnapshot() - - node_states = {} - for node_id, node in nodes.items(): - if not isinstance(node_id, str): - continue - node_states[node_id] = node.state - - edge_states = {} - for edge_id, edge in edges.items(): - if not isinstance(edge_id, str): - continue - edge_states[edge_id] = edge.state - - return _GraphStateSnapshot(nodes=node_states, edges=edge_states) - - def _apply_pending_graph_state(self) -> None: - if self._graph is None: - return - if self._pending_graph_node_states: - for node_id, state in self._pending_graph_node_states.items(): - node = self._graph.nodes.get(node_id) - if node is None: - continue - node.state = state - if self._pending_graph_edge_states: - for edge_id, state in self._pending_graph_edge_states.items(): - edge = self._graph.edges.get(edge_id) - if edge is None: - continue - edge.state = state - - self._pending_graph_node_states = None - self._pending_graph_edge_states = None - - -def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]: - if not isinstance(payload, Mapping): - return {} - raw_map = payload.get(key, {}) - if not isinstance(raw_map, Mapping): - return {} - result: dict[str, NodeState] = {} - for node_id, raw_state in raw_map.items(): - if not isinstance(node_id, str): - continue - try: - result[node_id] = NodeState(str(raw_state)) - except ValueError: - continue - return result diff --git a/api/graphon/runtime/graph_runtime_state_protocol.py b/api/graphon/runtime/graph_runtime_state_protocol.py deleted file mode 100644 index 856625a5d3..0000000000 --- a/api/graphon/runtime/graph_runtime_state_protocol.py +++ /dev/null @@ -1,79 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.variables.segments import Segment - - -class ReadOnlyVariablePool(Protocol): - """Read-only interface for VariablePool.""" - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Get a variable value (read-only).""" - ... 
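[Editor's note] `ReadOnlyVariablePool` above is a structural `typing.Protocol`: any object exposing a compatible `get` satisfies it without inheriting from it. A minimal sketch of a conforming implementation, assuming nothing beyond the standard library (the class name and the `Any` return type are hypothetical simplifications, not part of this diff):

from collections.abc import Sequence
from typing import Any


class FrozenPool:
    """Hypothetical read-only pool that satisfies ReadOnlyVariablePool structurally."""

    def __init__(self, data: dict[tuple[str, str], Any]) -> None:
        self._data = data

    def get(self, selector: Sequence[str], /) -> Any | None:
        # Selectors are [node_id, variable_name]; lookups never mutate state.
        # `Any` stands in for the real Segment return type to keep the sketch self-contained.
        return self._data.get((selector[0], selector[1]))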
- - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Get all variables for a node (read-only).""" - ... - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Get all variables stored under a given node prefix (read-only).""" - ... - - -class ReadOnlyGraphRuntimeState(Protocol): - """ - Read-only view of GraphRuntimeState for layers. - - This protocol defines a read-only interface that prevents layers from - modifying the graph runtime state while still allowing observation. - All methods return defensive copies to ensure immutability. - """ - - @property - def variable_pool(self) -> ReadOnlyVariablePool: - """Get read-only access to the variable pool.""" - ... - - @property - def start_at(self) -> float: - """Get the start time (read-only).""" - ... - - @property - def total_tokens(self) -> int: - """Get the total tokens count (read-only).""" - ... - - @property - def llm_usage(self) -> LLMUsage: - """Get a copy of LLM usage info (read-only).""" - ... - - @property - def outputs(self) -> dict[str, Any]: - """Get a defensive copy of outputs (read-only).""" - ... - - @property - def node_run_steps(self) -> int: - """Get the node run steps count (read-only).""" - ... - - @property - def ready_queue_size(self) -> int: - """Get the number of nodes currently in the ready queue.""" - ... - - @property - def exceptions_count(self) -> int: - """Get the number of node execution exceptions recorded.""" - ... - - def get_output(self, key: str, default: Any = None) -> Any: - """Get a single output value (returns a copy).""" - ... - - def dumps(self) -> str: - """Serialize the runtime state into a JSON snapshot (read-only).""" - ... diff --git a/api/graphon/runtime/read_only_wrappers.py b/api/graphon/runtime/read_only_wrappers.py deleted file mode 100644 index aaef255204..0000000000 --- a/api/graphon/runtime/read_only_wrappers.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Any - -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.variables.segments import Segment - -from .graph_runtime_state import GraphRuntimeState -from .variable_pool import VariablePool - - -class ReadOnlyVariablePoolWrapper: - """Provide defensive, read-only access to ``VariablePool``.""" - - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Return a copy of a variable value if present.""" - value = self._variable_pool.get(selector) - return deepcopy(value) if value is not None else None - - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Return a copy of all variables for the specified node.""" - variables: dict[str, object] = {} - if node_id in self._variable_pool.variable_dictionary: - for key, variable in self._variable_pool.variable_dictionary[node_id].items(): - variables[key] = deepcopy(variable.value) - return variables - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Return a copy of all variables stored under the given prefix.""" - return self._variable_pool.get_by_prefix(prefix) - - -class ReadOnlyGraphRuntimeStateWrapper: - """Expose a defensive, read-only view of ``GraphRuntimeState``.""" - - def __init__(self, state: GraphRuntimeState) -> None: - self._state = state - self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - - @property - def 
variable_pool(self) -> ReadOnlyVariablePoolWrapper: - return self._variable_pool_wrapper - - @property - def start_at(self) -> float: - return self._state.start_at - - @property - def total_tokens(self) -> int: - return self._state.total_tokens - - @property - def llm_usage(self) -> LLMUsage: - return self._state.llm_usage.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._state.outputs) - - @property - def node_run_steps(self) -> int: - return self._state.node_run_steps - - @property - def ready_queue_size(self) -> int: - return self._state.ready_queue.qsize() - - @property - def exceptions_count(self) -> int: - return self._state.graph_execution.exceptions_count - - def get_output(self, key: str, default: Any = None) -> Any: - return self._state.get_output(key, default) - - def dumps(self) -> str: - """Serialize the underlying runtime state for external persistence.""" - return self._state.dumps() diff --git a/api/graphon/runtime/variable_pool.py b/api/graphon/runtime/variable_pool.py deleted file mode 100644 index b44d1a8abe..0000000000 --- a/api/graphon/runtime/variable_pool.py +++ /dev/null @@ -1,279 +0,0 @@ -from __future__ import annotations - -import re -from collections import defaultdict -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Annotated, Any, Union, cast - -from pydantic import BaseModel, Field, model_validator - -from graphon.file import File, FileAttribute, file_manager -from graphon.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import FileSegment, ObjectSegment -from graphon.variables.variables import RAGPipelineVariableInput, Variable - -VariableValue = Union[str, int, float, dict[str, object], list[object], File] - -VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") - - -def _default_variable_dictionary() -> defaultdict[str, dict[str, Variable]]: - return defaultdict(dict) - - -class VariablePool(BaseModel): - _SYSTEM_VARIABLE_NODE_ID = "sys" - _ENVIRONMENT_VARIABLE_NODE_ID = "env" - _CONVERSATION_VARIABLE_NODE_ID = "conversation" - _RAG_PIPELINE_VARIABLE_NODE_ID = "rag" - - # Variable dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( - description="Variables mapping", - default_factory=_default_variable_dictionary, - ) - system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) - rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True) - user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True) - - @model_validator(mode="after") - def _load_legacy_bootstrap_inputs(self) -> VariablePool: - """ - Accept legacy constructor kwargs that still appear throughout the workflow - layer while keeping serialized state focused on `variable_dictionary`. 
- """ - - self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID) - self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID) - self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID) - self._ingest_legacy_rag_variables(self.rag_pipeline_variables) - - # These kwargs are accepted for compatibility but should not affect the - # stable serialized form or model equality. - self.system_variables = () - self.environment_variables = () - self.conversation_variables = () - self.rag_pipeline_variables = () - self.user_inputs = {} - return self - - def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None: - for variable in variables: - selector = [node_id, variable.name] - normalized_variable = variable - if list(variable.selector) != selector: - normalized_variable = variable.model_copy(update={"selector": selector}) - self.add(normalized_variable.selector, normalized_variable) - - def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None: - if not rag_pipeline_variables: - return - - values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict) - for rag_variable_input in rag_pipeline_variables: - values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = ( - rag_variable_input.value - ) - - for node_id, value in values_by_node_id.items(): - self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value) - - def add(self, selector: Sequence[str], value: Any, /): - """ - Add a variable to the variable pool. - - This method accepts a selector path and a value, converting the value - to a Variable object if necessary before storing it in the pool. - - Args: - selector: A two-element sequence containing [node_id, variable_name]. - The selector must have exactly 2 elements to be valid. - value: The value to store. Can be a Variable, Segment, or any value - that can be converted to a Segment (str, int, float, dict, list, File). - - Raises: - ValueError: If selector length is not exactly 2 elements. - - Note: - While non-Segment values are currently accepted and automatically - converted, it's recommended to pass Segment or Variable objects directly. - """ - if len(selector) != SELECTORS_LENGTH: - raise ValueError( - f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), " - f"got {len(selector)} elements" - ) - - if isinstance(value, VariableBase): - variable = value - elif isinstance(value, Segment): - variable = segment_to_variable(segment=value, selector=selector) - else: - segment = build_segment(value) - variable = segment_to_variable(segment=segment, selector=selector) - - node_id, name = self._selector_to_keys(selector) - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. 
self.variable_dictionary[node_id][name] = cast(Variable, variable) - - @classmethod - def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: - return selector[0], selector[1] - - def _has(self, selector: Sequence[str]) -> bool: - node_id, name = self._selector_to_keys(selector) - if node_id not in self.variable_dictionary: - return False - if name not in self.variable_dictionary[node_id]: - return False - return True - - def get(self, selector: Sequence[str], /) -> Segment | None: - """ - Retrieve a variable's value from the pool as a Segment. - - This method supports both simple selectors [node_id, variable_name] and - extended selectors that include attribute access for FileSegment and - ObjectSegment types. - - Args: - selector: A sequence with at least 2 elements: - - [node_id, variable_name]: Returns the full segment - - [node_id, variable_name, attr, ...]: Returns a nested value - from FileSegment (e.g., 'url', 'name') or ObjectSegment - - Returns: - The Segment associated with the selector, or None if not found. - Returns None if selector has fewer than 2 elements. - - Raises: - ValueError: If attempting to access an invalid FileAttribute. - """ - if len(selector) < SELECTORS_LENGTH: - return None - - node_id, name = self._selector_to_keys(selector) - node_map = self.variable_dictionary.get(node_id) - if node_map is None: - return None - - segment: Segment | None = node_map.get(name) - - if segment is None: - return None - - if len(selector) == 2: - return segment - - if isinstance(segment, FileSegment): - attr = selector[2] - # Python supports `attr in FileAttribute` only from 3.12 onward - if attr not in {item.value for item in FileAttribute}: - return None - attr = FileAttribute(attr) - attr_value = file_manager.get_attr(file=segment.value, attr=attr) - return build_segment(attr_value) - - # Navigate through nested attributes - result: Any = segment - for attr in selector[2:]: - result = self._extract_value(result) - result = self._get_nested_attribute(result, attr) - if result is None: - return None - - # Return result as Segment - return result if isinstance(result, Segment) else build_segment(result) - - def _extract_value(self, obj: Any): - """Extract the actual value from an ObjectSegment.""" - return obj.value if isinstance(obj, ObjectSegment) else obj - - def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None: - """ - Get a nested attribute from a dictionary-like object. - - Args: - obj: The dictionary-like object to search. - attr: The key to look up. - - Returns: - Segment | None: - The corresponding Segment built from the attribute value if the key exists, - otherwise None. - """ - if not isinstance(obj, dict) or attr not in obj: - return None - return build_segment(obj.get(attr)) - - def remove(self, selector: Sequence[str], /): - """ - Remove variables from the variable pool based on the given selector. - - Args: - selector (Sequence[str]): A sequence of strings representing the selector. - - Returns: - None - """ - if not selector: - return - if len(selector) == 1: - self.variable_dictionary[selector[0]] = {} - return - key, hash_key = self._selector_to_keys(selector) - self.variable_dictionary[key].pop(hash_key, None) - - def convert_template(self, template: str, /): - parts = VARIABLE_PATTERN.split(template) - segments: list[Segment] = [] - for part in filter(lambda x: x, parts): - if "."
in part and (variable := self.get(part.split("."))): - segments.append(variable) - else: - segments.append(build_segment(part)) - return SegmentGroup(value=segments) - - def get_file(self, selector: Sequence[str], /) -> FileSegment | None: - segment = self.get(selector) - if isinstance(segment, FileSegment): - return segment - return None - - def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: - """Return a copy of all variables stored under the given node prefix.""" - - nodes = self.variable_dictionary.get(prefix) - if not nodes: - return {} - - result: dict[str, object] = {} - for key, variable in nodes.items(): - value = variable.value - result[key] = deepcopy(value) - - return result - - def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]: - """Return a selector-style snapshot of the entire variable pool.""" - - result: dict[str, object] = {} - for node_id, variables in self.variable_dictionary.items(): - for name, variable in variables.items(): - output_name = name if node_id == unprefixed_node_id else f"{node_id}.{name}" - result[output_name] = deepcopy(variable.value) - - return result - - @classmethod - def empty(cls) -> VariablePool: - """Create an empty variable pool.""" - return cls() diff --git a/api/graphon/template_rendering.py b/api/graphon/template_rendering.py deleted file mode 100644 index 0527e58f6d..0000000000 --- a/api/graphon/template_rendering.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Mapping -from typing import Any - - -class TemplateRenderError(ValueError): - """Raised when rendering a template fails.""" - - -class Jinja2TemplateRenderer(ABC): - """Nominal renderer contract for Jinja2 template rendering in graph nodes.""" - - @abstractmethod - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render the template into plain text.""" - raise NotImplementedError diff --git a/api/graphon/utils/__init__.py b/api/graphon/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/utils/condition/__init__.py b/api/graphon/utils/condition/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/graphon/utils/condition/entities.py b/api/graphon/utils/condition/entities.py deleted file mode 100644 index 77a214571a..0000000000 --- a/api/graphon/utils/condition/entities.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Sequence -from typing import Literal - -from pydantic import BaseModel, Field - -SupportedComparisonOperator = Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - "in", - "not in", - "all of", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - "null", - "not null", - # for file - "exists", - "not exists", -] - - -class SubCondition(BaseModel): - key: str - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | None = None - - -class SubVariableCondition(BaseModel): - logical_operator: Literal["and", "or"] - conditions: list[SubCondition] = Field(default_factory=list) - - -class Condition(BaseModel): - variable_selector: list[str] - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | bool | None = None - sub_variable_condition: SubVariableCondition | None = None diff --git a/api/graphon/utils/condition/processor.py b/api/graphon/utils/condition/processor.py deleted 
file mode 100644 index 03535927cb..0000000000 --- a/api/graphon/utils/condition/processor.py +++ /dev/null @@ -1,504 +0,0 @@ -import json -from collections.abc import Mapping, Sequence -from typing import Literal, NamedTuple - -from graphon.file import FileAttribute, file_manager -from graphon.runtime import VariablePool -from graphon.variables import ArrayFileSegment -from graphon.variables.segments import ArrayBooleanSegment, BooleanSegment - -from .entities import Condition, SubCondition, SupportedComparisonOperator - - -def _convert_to_bool(value: object) -> bool: - if isinstance(value, int): - return bool(value) - - if isinstance(value, str): - loaded = json.loads(value) - if isinstance(loaded, (int, bool)): - return bool(loaded) - - raise TypeError(f"unexpected value: type={type(value)}, value={value}") - - -class ConditionCheckResult(NamedTuple): - inputs: Sequence[Mapping[str, object]] - group_results: Sequence[bool] - final_result: bool - - -class ConditionProcessor: - def process_conditions( - self, - *, - variable_pool: VariablePool, - conditions: Sequence[Condition], - operator: Literal["and", "or"], - ) -> ConditionCheckResult: - input_conditions: list[Mapping[str, object]] = [] - group_results: list[bool] = [] - - for condition in conditions: - variable = variable_pool.get(condition.variable_selector) - if variable is None: - raise ValueError(f"Variable {condition.variable_selector} not found") - - if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in { - "contains", - "not contains", - "all of", - }: - # check sub conditions - if not condition.sub_variable_condition: - raise ValueError("Sub variable is required") - result = _process_sub_conditions( - variable=variable, - sub_conditions=condition.sub_variable_condition.conditions, - operator=condition.sub_variable_condition.logical_operator, - ) - elif condition.comparison_operator in { - "exists", - "not exists", - }: - result = _evaluate_condition( - value=variable.value, - operator=condition.comparison_operator, - expected=None, - ) - else: - actual_value = variable.value if variable else None - expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value - if isinstance(expected_value, str): - expected_value = variable_pool.convert_template(expected_value).text - # Here we need to explicitly convert the input string to a boolean. - if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None: - # The following two lines are for compatibility with existing workflows.
- if isinstance(expected_value, list): - expected_value = [_convert_to_bool(i) for i in expected_value] - else: - expected_value = _convert_to_bool(expected_value) - input_conditions.append( - { - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": condition.comparison_operator, - } - ) - result = _evaluate_condition( - value=actual_value, - operator=condition.comparison_operator, - expected=expected_value, - ) - group_results.append(result) - # Implemented short-circuit evaluation for logical conditions - if (operator == "and" and not result) or (operator == "or" and result): - final_result = result - return ConditionCheckResult(input_conditions, group_results, final_result) - - final_result = all(group_results) if operator == "and" else any(group_results) - return ConditionCheckResult(input_conditions, group_results, final_result) - - -def _evaluate_condition( - *, - operator: SupportedComparisonOperator, - value: object, - expected: str | Sequence[str] | bool | Sequence[bool] | None, -) -> bool: - match operator: - case "contains": - return _assert_contains(value=value, expected=expected) - case "not contains": - return _assert_not_contains(value=value, expected=expected) - case "start with": - return _assert_start_with(value=value, expected=expected) - case "end with": - return _assert_end_with(value=value, expected=expected) - case "is": - return _assert_is(value=value, expected=expected) - case "is not": - return _assert_is_not(value=value, expected=expected) - case "empty": - return _assert_empty(value=value) - case "not empty": - return _assert_not_empty(value=value) - case "=": - return _assert_equal(value=value, expected=expected) - case "≠": - return _assert_not_equal(value=value, expected=expected) - case ">": - return _assert_greater_than(value=value, expected=expected) - case "<": - return _assert_less_than(value=value, expected=expected) - case "≥": - return _assert_greater_than_or_equal(value=value, expected=expected) - case "≤": - return _assert_less_than_or_equal(value=value, expected=expected) - case "null": - return _assert_null(value=value) - case "not null": - return _assert_not_null(value=value) - case "in": - return _assert_in(value=value, expected=expected) - case "not in": - return _assert_not_in(value=value, expected=expected) - case "all of" if isinstance(expected, list): - # Type narrowing: at this point expected is a list, could be list[str] or list[bool] - if all(isinstance(item, str) for item in expected): - # Create a new typed list to satisfy type checker - str_list: list[str] = [item for item in expected if isinstance(item, str)] - return _assert_all_of(value=value, expected=str_list) - elif all(isinstance(item, bool) for item in expected): - # Create a new typed list to satisfy type checker - bool_list: list[bool] = [item for item in expected if isinstance(item, bool)] - return _assert_all_of_bool(value=value, expected=bool_list) - else: - raise ValueError("all of operator expects homogeneous list of strings or booleans") - case "exists": - return _assert_exists(value=value) - case "not exists": - return _assert_not_exists(value=value) - case _: - raise ValueError(f"Unsupported operator: {operator}") - - -def _assert_contains(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not 
isinstance(expected, str): - expected = str(expected) - if expected not in value: - return False - else: # value is list - if expected not in value: - return False - return True - - -def _assert_not_contains(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not isinstance(expected, str): - expected = str(expected) - if expected in value: - return False - else: # value is list - if expected in value: - return False - return True - - -def _assert_start_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for startswith") - if not value.startswith(expected): - return False - return True - - -def _assert_end_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for endswith") - if not value.endswith(expected): - return False - return True - - -def _assert_is(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value != expected: - return False - return True - - -def _assert_is_not(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value == expected: - return False - return True - - -def _assert_empty(*, value: object) -> bool: - if not value: - return True - return False - - -def _assert_not_empty(*, value: object) -> bool: - if value: - return True - return False - - -def _normalize_numeric_values(value: int | float, expected: object) -> tuple[int | float, int | float]: - """ - Normalize value and expected to compatible numeric types for comparison. 
- - Args: - value: The actual numeric value (int or float) - expected: The expected value (int, float, or str) - - Returns: - A tuple of (normalized_value, normalized_expected) with compatible types - - Raises: - ValueError: If expected cannot be converted to a number - """ - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to number") - - # Convert expected to appropriate numeric type - if isinstance(expected, str): - # Try to convert to float first to handle decimal strings - try: - expected_float = float(expected) - except ValueError as e: - raise ValueError(f"Cannot convert '{expected}' to number") from e - - # If value is int and expected is a whole number, keep as int comparison - if isinstance(value, int) and expected_float.is_integer(): - return value, int(expected_float) - else: - # Otherwise convert value to float for comparison - return float(value) if isinstance(value, int) else value, expected_float - elif isinstance(expected, float): - # If expected is already float, convert int value to float - return float(value) if isinstance(value, int) else value, expected - else: - # expected is int - return value, expected - - -def _assert_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value != expected: - return False - return True - - -def _assert_not_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value == expected: - return False - return True - - -def _assert_greater_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value > expected - - -def _assert_less_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value < expected - - -def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, 
float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value >= expected - - -def _assert_less_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value <= expected - - -def _assert_null(*, value: object) -> bool: - if value is None: - return True - return False - - -def _assert_not_null(*, value: object) -> bool: - if value is not None: - return True - return False - - -def _assert_in(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value not in expected: - return False - return True - - -def _assert_not_in(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value in expected: - return False - return True - - -def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set, str)): - return False - - return all(item in value for item in expected) - - -def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set)): - return False - - return all(item in value for item in expected) - - -def _assert_exists(*, value: object) -> bool: - return value is not None - - -def _assert_not_exists(*, value: object) -> bool: - return value is None - - -def _process_sub_conditions( - variable: ArrayFileSegment, - sub_conditions: Sequence[SubCondition], - operator: Literal["and", "or"], -) -> bool: - files = variable.value - group_results: list[bool] = [] - for condition in sub_conditions: - key = FileAttribute(condition.key) - values = [file_manager.get_attr(file=file, attr=key) for file in files] - expected_value = condition.value - if key == FileAttribute.EXTENSION: - if not isinstance(expected_value, str): - raise TypeError("Expected value must be a string when key is FileAttribute.EXTENSION") - if expected_value and not expected_value.startswith("."): - expected_value = "." + expected_value - - normalized_values: list[object] = [] - for value in values: - if value and isinstance(value, str): - if not value.startswith("."): - value = "." 
+ value - normalized_values.append(value) - values = normalized_values - sub_group_results: list[bool] = [ - _evaluate_condition( - value=value, - operator=condition.comparison_operator, - expected=expected_value, - ) - for value in values - ] - # Determine the result based on the presence of "not" in the comparison operator - result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results) - group_results.append(result) - return all(group_results) if operator == "and" else any(group_results) diff --git a/api/graphon/utils/json_in_md_parser.py b/api/graphon/utils/json_in_md_parser.py deleted file mode 100644 index 4416b4774b..0000000000 --- a/api/graphon/utils/json_in_md_parser.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -import json - - -class OutputParserError(ValueError): - """Raised when a markdown-wrapped JSON payload cannot be parsed or validated.""" - - -def parse_json_markdown(json_string: str) -> dict | list: - """Extract and parse the first JSON object or array embedded in markdown text.""" - json_string = json_string.strip() - starts = ["```json", "```", "``", "`", "{", "["] - ends = ["```", "``", "`", "}", "]"] - end_index = -1 - start_index = 0 - - for start_marker in starts: - start_index = json_string.find(start_marker) - if start_index != -1: - if json_string[start_index] not in ("{", "["): - start_index += len(start_marker) - break - - if start_index != -1: - for end_marker in ends: - end_index = json_string.rfind(end_marker, start_index) - if end_index != -1: - if json_string[end_index] in ("}", "]"): - end_index += 1 - break - - if start_index == -1 or end_index == -1 or start_index >= end_index: - raise ValueError("could not find json block in the output.") - - extracted_content = json_string[start_index:end_index].strip() - return json.loads(extracted_content) - - -def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: - try: - json_obj = parse_json_markdown(text) - except json.JSONDecodeError as exc: - raise OutputParserError(f"got invalid json object. error: {exc}") from exc - - if isinstance(json_obj, list): - if len(json_obj) == 1 and isinstance(json_obj[0], dict): - json_obj = json_obj[0] - else: - raise OutputParserError(f"got invalid return object. obj:{json_obj}") - - for key in expected_keys: - if key not in json_obj: - raise OutputParserError( - f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" - ) - - return json_obj diff --git a/api/graphon/variable_loader.py b/api/graphon/variable_loader.py deleted file mode 100644 index 03db920d3d..0000000000 --- a/api/graphon/variable_loader.py +++ /dev/null @@ -1,75 +0,0 @@ -import abc -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from graphon.runtime import VariablePool -from graphon.variables import VariableBase -from graphon.variables.consts import SELECTORS_LENGTH - - -class VariableLoader(Protocol): - """Interface for loading variables based on selectors. - - A `VariableLoader` is responsible for retrieving additional variables required during the execution - of a single node, which are not provided as user inputs. - - TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of the node instance into - `WorkflowService.single_step_run`, we may get rid of this interface. - """ - - @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - """Load variables based on the provided selectors.
If the selectors are empty, - this method should return an empty list. - - The order of the returned variables is not guaranteed. If the caller wants to ensure - a specific order, they should sort the returned list themselves. - - :param selectors: a list of string lists; each inner list should have at least two elements: - - the first element is the node ID, - - the second element is the variable name. - :return: a list of VariableBase objects that match the provided selectors. - """ - pass - - -class _DummyVariableLoader(VariableLoader): - """A dummy implementation of VariableLoader that does not load any variables. - Serves as a placeholder when no variable loading is needed. - """ - - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - return [] - - -DUMMY_VARIABLE_LOADER = _DummyVariableLoader() - - -def load_into_variable_pool( - variable_loader: VariableLoader, - variable_pool: VariablePool, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: Mapping[str, Any], -): - # Load any missing variables from draft variables here and set them into the - # variable pool. - variables_to_load: list[list[str]] = [] - for key, selector in variable_mapping.items(): - # NOTE(QuantumGhost): this logic needs to be in sync with - # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. - node_variable_list = key.split(".") - if len(node_variable_list) < 2: - raise ValueError(f"Invalid variable key: {key}. It should have at least two elements.") - if key in user_inputs: - continue - node_variable_key = ".".join(node_variable_list[1:]) - if node_variable_key in user_inputs: - continue - if variable_pool.get(selector) is None: - variables_to_load.append(list(selector)) - loaded = variable_loader.load_variables(variables_to_load) - for var in loaded: - assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}" - # Add variable directly to the pool - # The variable pool expects 2-element selectors [node_id, variable_name] - variable_pool.add([var.selector[0], var.selector[1]], var) diff --git a/api/graphon/variables/__init__.py b/api/graphon/variables/__init__.py deleted file mode 100644 index e9beb6cb95..0000000000 --- a/api/graphon/variables/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -from .factory import ( - TypeMismatchError, - UnsupportedSegmentTypeError, - build_segment, - build_segment_with_type, - segment_to_variable, -) -from .input_entities import VariableEntity, VariableEntityType -from .segment_group import SegmentGroup -from .segments import ( - ArrayAnySegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayFileVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - ArrayVariable, - FileVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - SecretVariable, - StringVariable, - Variable, - VariableBase, -) - -__all__ = [ - "ArrayAnySegment", - "ArrayAnyVariable", - "ArrayFileSegment", - "ArrayFileVariable", - "ArrayNumberSegment", - "ArrayNumberVariable", - "ArrayObjectSegment", - "ArrayObjectVariable", - "ArraySegment", - "ArrayStringSegment", - "ArrayStringVariable", - "ArrayVariable", - "FileSegment", - "FileVariable", - "FloatSegment", - "FloatVariable", - "IntegerSegment", - "IntegerVariable", - "NoneSegment", - "NoneVariable", - "ObjectSegment", -
"ObjectVariable", - "SecretVariable", - "Segment", - "SegmentGroup", - "SegmentType", - "StringSegment", - "StringVariable", - "TypeMismatchError", - "UnsupportedSegmentTypeError", - "Variable", - "VariableBase", - "VariableEntity", - "VariableEntityType", - "build_segment", - "build_segment_with_type", - "segment_to_variable", -] diff --git a/api/graphon/variables/consts.py b/api/graphon/variables/consts.py deleted file mode 100644 index 8f3f78f740..0000000000 --- a/api/graphon/variables/consts.py +++ /dev/null @@ -1,7 +0,0 @@ -# The minimal selector length for valid variables. -# -# The first element of the selector is the node id, and the second element is the variable name. -# -# If the selector length is more than 2, the remaining parts are the keys / indexes paths used -# to extract part of the variable value. -SELECTORS_LENGTH = 2 diff --git a/api/graphon/variables/exc.py b/api/graphon/variables/exc.py deleted file mode 100644 index 5cf67c3bac..0000000000 --- a/api/graphon/variables/exc.py +++ /dev/null @@ -1,2 +0,0 @@ -class VariableError(ValueError): - pass diff --git a/api/graphon/variables/factory.py b/api/graphon/variables/factory.py deleted file mode 100644 index ac693914a7..0000000000 --- a/api/graphon/variables/factory.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Graph-owned helpers for converting runtime values, segments, and variables. - -These conversions are part of the `graphon` runtime model and must stay -independent from top-level API factory modules so graph nodes and state -containers can operate without importing application-layer packages. -""" - -from collections.abc import Mapping, Sequence -from typing import Any, cast -from uuid import uuid4 - -from graphon.file import File - -from .segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayBooleanVariable, - ArrayFileVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - BooleanVariable, - FileVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - StringVariable, - VariableBase, -) - - -class UnsupportedSegmentTypeError(Exception): - pass - - -class TypeMismatchError(Exception): - pass - - -SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Any]] = { - ArrayAnySegment: ArrayAnyVariable, - ArrayBooleanSegment: ArrayBooleanVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayNumberSegment: ArrayNumberVariable, - ArrayObjectSegment: ArrayObjectVariable, - ArrayStringSegment: ArrayStringVariable, - BooleanSegment: BooleanVariable, - FileSegment: FileVariable, - FloatSegment: FloatVariable, - IntegerSegment: IntegerVariable, - NoneSegment: NoneVariable, - ObjectSegment: ObjectVariable, - StringSegment: StringVariable, -} - - -def build_segment(value: Any, /) -> Segment: - """Build a runtime segment from a Python value.""" - if value is None: - return NoneSegment() - if isinstance(value, Segment): - return value - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, bool): - return BooleanSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if 
isinstance(value, File): - return FileSegment(value=value) - if isinstance(value, list): - items = [build_segment(item) for item in value] - types = {item.value_type for item in items} - if all(isinstance(item, ArraySegment) for item in items): - return ArrayAnySegment(value=value) - if len(types) != 1: - if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): - return ArrayNumberSegment(value=value) - return ArrayAnySegment(value=value) - - match types.pop(): - case SegmentType.STRING: - return ArrayStringSegment(value=value) - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return ArrayNumberSegment(value=value) - case SegmentType.BOOLEAN: - return ArrayBooleanSegment(value=value) - case SegmentType.OBJECT: - return ArrayObjectSegment(value=value) - case SegmentType.FILE: - return ArrayFileSegment(value=value) - case SegmentType.NONE: - return ArrayAnySegment(value=value) - case _: - raise ValueError(f"not supported value {value}") - raise ValueError(f"not supported value {value}") - - -_SEGMENT_FACTORY: Mapping[SegmentType, type[Segment]] = { - SegmentType.NONE: NoneSegment, - SegmentType.STRING: StringSegment, - SegmentType.INTEGER: IntegerSegment, - SegmentType.FLOAT: FloatSegment, - SegmentType.FILE: FileSegment, - SegmentType.BOOLEAN: BooleanSegment, - SegmentType.OBJECT: ObjectSegment, - SegmentType.ARRAY_ANY: ArrayAnySegment, - SegmentType.ARRAY_STRING: ArrayStringSegment, - SegmentType.ARRAY_NUMBER: ArrayNumberSegment, - SegmentType.ARRAY_OBJECT: ArrayObjectSegment, - SegmentType.ARRAY_FILE: ArrayFileSegment, - SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, -} - - -def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: - """Build a segment while enforcing compatibility with the expected runtime type.""" - if value is None: - if segment_type == SegmentType.NONE: - return NoneSegment() - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") - - if isinstance(value, list) and len(value) == 0: - if segment_type == SegmentType.ARRAY_ANY: - return ArrayAnySegment(value=value) - if segment_type == SegmentType.ARRAY_STRING: - return ArrayStringSegment(value=value) - if segment_type == SegmentType.ARRAY_BOOLEAN: - return ArrayBooleanSegment(value=value) - if segment_type == SegmentType.ARRAY_NUMBER: - return ArrayNumberSegment(value=value) - if segment_type == SegmentType.ARRAY_OBJECT: - return ArrayObjectSegment(value=value) - if segment_type == SegmentType.ARRAY_FILE: - return ArrayFileSegment(value=value) - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") - - inferred_type = SegmentType.infer_segment_type(value) - if inferred_type is None: - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" - ) - if inferred_type == segment_type: - segment_class = _SEGMENT_FACTORY[segment_type] - return segment_class(value_type=segment_type, value=value) - if segment_type == SegmentType.NUMBER and inferred_type in (SegmentType.INTEGER, SegmentType.FLOAT): - segment_class = _SEGMENT_FACTORY[inferred_type] - return segment_class(value_type=inferred_type, value=value) - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") - - -def segment_to_variable( - *, - segment: Segment, - selector: Sequence[str], - id: str | None = None, - name: str | None = None, - description: str = "", -) -> VariableBase: - """Convert a runtime segment into a runtime 
variable for storage in the pool.""" - if isinstance(segment, VariableBase): - return segment - name = name or selector[-1] - id = id or str(uuid4()) - - segment_type = type(segment) - if segment_type not in SEGMENT_TO_VARIABLE_MAP: - raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return cast( - VariableBase, - variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=list(selector), - ), - ) diff --git a/api/graphon/variables/input_entities.py b/api/graphon/variables/input_entities.py deleted file mode 100644 index c46ee47714..0000000000 --- a/api/graphon/variables/input_entities.py +++ /dev/null @@ -1,62 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from jsonschema import Draft7Validator, SchemaError -from pydantic import BaseModel, Field, field_validator - -from graphon.file import FileTransferMethod, FileType - - -class VariableEntityType(StrEnum): - TEXT_INPUT = "text-input" - SELECT = "select" - PARAGRAPH = "paragraph" - NUMBER = "number" - EXTERNAL_DATA_TOOL = "external_data_tool" - FILE = "file" - FILE_LIST = "file-list" - CHECKBOX = "checkbox" - JSON_OBJECT = "json_object" - - -class VariableEntity(BaseModel): - """ - Shared variable entity used by workflow runtime and app configuration. - """ - - # `variable` records the name of the variable in user inputs. - variable: str - label: str - description: str = "" - type: VariableEntityType - required: bool = False - hide: bool = False - default: Any = None - max_length: int | None = None - options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) - allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) - - @field_validator("description", mode="before") - @classmethod - def convert_none_description(cls, value: Any) -> str: - return value or "" - - @field_validator("options", mode="before") - @classmethod - def convert_none_options(cls, value: Any) -> Sequence[str]: - return value or [] - - @field_validator("json_schema") - @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: - if schema is None: - return None - try: - Draft7Validator.check_schema(schema) - except SchemaError as error: - raise ValueError(f"Invalid JSON schema: {error.message}") - return schema diff --git a/api/graphon/variables/segment_group.py b/api/graphon/variables/segment_group.py deleted file mode 100644 index b363255b2c..0000000000 --- a/api/graphon/variables/segment_group.py +++ /dev/null @@ -1,22 +0,0 @@ -from .segments import Segment -from .types import SegmentType - - -class SegmentGroup(Segment): - value_type: SegmentType = SegmentType.GROUP - value: list[Segment] - - @property - def text(self): - return "".join([segment.text for segment in self.value]) - - @property - def log(self): - return "".join([segment.log for segment in self.value]) - - @property - def markdown(self): - return "".join([segment.markdown for segment in self.value]) - - def to_object(self): - return [segment.to_object() for segment in self.value] diff --git a/api/graphon/variables/segments.py b/api/graphon/variables/segments.py deleted file mode 100644 index 8902ddc7e9..0000000000 --- 
a/api/graphon/variables/segments.py +++ /dev/null @@ -1,253 +0,0 @@ -import json -import sys -from collections.abc import Mapping, Sequence -from typing import Annotated, Any, TypeAlias - -from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator - -from graphon.file import File - -from .types import SegmentType - - -class Segment(BaseModel): - """Segment is a runtime type used during the execution of a workflow. - - Note: this class is abstract; use one of its subclasses instead. - """ - - model_config = ConfigDict(frozen=True) - - value_type: SegmentType - value: Any - - @field_validator("value_type") - @classmethod - def validate_value_type(cls, value): - """ - This validator checks if the provided value is equal to the default value of the 'value_type' field. - If the value is different, a ValueError is raised. - """ - if value != cls.model_fields["value_type"].default: - raise ValueError("Cannot modify 'value_type'") - return value - - @property - def text(self) -> str: - return str(self.value) - - @property - def log(self) -> str: - return str(self.value) - - @property - def markdown(self) -> str: - return str(self.value) - - @property - def size(self) -> int: - """ - Return the size of the value in bytes. - """ - return sys.getsizeof(self.value) - - def to_object(self): - return self.value - - -class NoneSegment(Segment): - value_type: SegmentType = SegmentType.NONE - value: None = None - - @property - def text(self) -> str: - return "" - - @property - def log(self) -> str: - return "" - - @property - def markdown(self) -> str: - return "" - - -class StringSegment(Segment): - value_type: SegmentType = SegmentType.STRING - value: str - - -class FloatSegment(Segment): - value_type: SegmentType = SegmentType.FLOAT - value: float - # NOTE(QuantumGhost): it seems that equality for FloatSegment with `NaN` values has some problems. - # The following tests cannot pass.
-class IntegerSegment(Segment):
-    value_type: SegmentType = SegmentType.INTEGER
-    value: int
-
-
-class ObjectSegment(Segment):
-    value_type: SegmentType = SegmentType.OBJECT
-    value: Mapping[str, Any]
-
-    @property
-    def text(self) -> str:
-        return json.dumps(self.model_dump()["value"], ensure_ascii=False)
-
-    @property
-    def log(self) -> str:
-        return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
-
-    @property
-    def markdown(self) -> str:
-        return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
-
-
-class ArraySegment(Segment):
-    @property
-    def text(self) -> str:
-        # Return empty string for empty arrays instead of "[]"
-        if not self.value:
-            return ""
-        return super().text
-
-    @property
-    def markdown(self) -> str:
-        items = []
-        for item in self.value:
-            items.append(f"- {item}")
-        return "\n".join(items)
-
-
-class FileSegment(Segment):
-    value_type: SegmentType = SegmentType.FILE
-    value: File
-
-    @property
-    def markdown(self) -> str:
-        return self.value.markdown
-
-    @property
-    def log(self) -> str:
-        return ""
-
-    @property
-    def text(self) -> str:
-        return ""
-
-
-class BooleanSegment(Segment):
-    value_type: SegmentType = SegmentType.BOOLEAN
-    value: bool
-
-
-class ArrayAnySegment(ArraySegment):
-    value_type: SegmentType = SegmentType.ARRAY_ANY
-    value: Sequence[Any]
-
-
-class ArrayStringSegment(ArraySegment):
-    value_type: SegmentType = SegmentType.ARRAY_STRING
-    value: Sequence[str]
-
-    @property
-    def text(self) -> str:
-        # Return empty string for empty arrays instead of "[]"
-        if not self.value:
-            return ""
-        return json.dumps(self.value, ensure_ascii=False)
-
-
-class ArrayNumberSegment(ArraySegment):
-    value_type: SegmentType = SegmentType.ARRAY_NUMBER
-    value: Sequence[float | int]
-
-
-class ArrayObjectSegment(ArraySegment):
-    value_type: SegmentType = SegmentType.ARRAY_OBJECT
-    value: Sequence[Mapping[str, Any]]
-
-
-class ArrayFileSegment(ArraySegment):
-    value_type: SegmentType = SegmentType.ARRAY_FILE
-    value: Sequence[File]
-
-    @property
-    def markdown(self) -> str:
-        items = []
-        for item in self.value:
-            items.append(item.markdown)
-        return "\n".join(items)
-
-    @property
-    def log(self) -> str:
-        return ""
-
-    @property
-    def text(self) -> str:
-        return ""
-
-
-class ArrayBooleanSegment(ArraySegment):
-    value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
-    value: Sequence[bool]
-
-
-def get_segment_discriminator(v: Any) -> SegmentType | None:
-    if isinstance(v, Segment):
-        return v.value_type
-    elif isinstance(v, dict):
-        value_type = v.get("value_type")
-        if value_type is None:
-            return None
-        try:
-            seg_type = SegmentType(value_type)
-        except ValueError:
-            return None
-        return seg_type
-    else:
-        # return None if the discriminator value isn't found
-        return None
-
-
-# The `SegmentUnion` type is used to enable serialization and deserialization with Pydantic.
-# Use `Segment` for type hinting when serialization is not required.
-#
-# Note:
-# - All variants in `SegmentUnion` must inherit from the `Segment` class.
-# - The union must include all non-abstract subclasses of `Segment`, except:
-#   - `SegmentGroup`, which is not added to the variable pool.
-#   - `VariableBase` and its subclasses, which are handled by `Variable`.
-SegmentUnion: TypeAlias = Annotated[
-    (
-        Annotated[NoneSegment, Tag(SegmentType.NONE)]
-        | Annotated[StringSegment, Tag(SegmentType.STRING)]
-        | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
-        | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
-        | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
-        | Annotated[FileSegment, Tag(SegmentType.FILE)]
-        | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
-        | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
-        | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
-        | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
-        | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
-        | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
-        | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
-    ),
-    Discriminator(get_segment_discriminator),
-]
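For reference, the discriminated union above was consumed through pydantic's `TypeAdapter`; a hedged sketch of the round-trip, assuming the deleted classes were still importable:

```python
from pydantic import TypeAdapter

adapter = TypeAdapter(SegmentUnion)  # assumes the deleted SegmentUnion is importable

# Dict input: get_segment_discriminator reads "value_type" and picks the tagged variant.
seg = adapter.validate_python({"value_type": "string", "value": "hello"})
assert isinstance(seg, StringSegment) and seg.text == "hello"

# Segment input: the discriminator returns v.value_type, so live instances
# re-validate as a no-op dispatch.
assert isinstance(adapter.validate_python(seg), StringSegment)
```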
- raise ValueError(f"not supported value {value}") - if value is None: - return SegmentType.NONE - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. - elif isinstance(value, bool): - return SegmentType.BOOLEAN - elif isinstance(value, int): - return SegmentType.INTEGER - elif isinstance(value, float): - return SegmentType.FLOAT - elif isinstance(value, str): - return SegmentType.STRING - elif isinstance(value, dict): - return SegmentType.OBJECT - elif isinstance(value, File): - return SegmentType.FILE - else: - return None - - def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: - if not isinstance(value, list): - return False - # Skip element validation if array is empty - if len(value) == 0: - return True - if self == SegmentType.ARRAY_ANY: - return True - element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] - - if array_validation == ArrayValidation.NONE: - return True - elif array_validation == ArrayValidation.FIRST: - return element_type.is_valid(value[0]) - else: - return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) - - def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool: - """ - Check if a value matches the segment type. - Users of `SegmentType` should call this method, instead of using - `isinstance` manually. - - Args: - value: The value to validate - array_validation: Validation strategy for array types (ignored for non-array types) - - Returns: - True if the value matches the type under the given validation strategy - """ - if self.is_array_type(): - return self._validate_array(value, array_validation) - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. - elif self == SegmentType.BOOLEAN: - return isinstance(value, bool) - elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]: - return isinstance(value, (int, float)) - elif self == SegmentType.STRING: - return isinstance(value, str) - elif self == SegmentType.OBJECT: - return isinstance(value, dict) - elif self == SegmentType.SECRET: - return isinstance(value, str) - elif self == SegmentType.FILE: - return isinstance(value, File) - elif self == SegmentType.NONE: - return value is None - elif self == SegmentType.GROUP: - from .segment_group import SegmentGroup - from .segments import Segment - - if isinstance(value, SegmentGroup): - return all(isinstance(item, Segment) for item in value.value) - - if isinstance(value, list): - return all(isinstance(item, Segment) for item in value) - - return False - else: - raise AssertionError("this statement should be unreachable.") - - @staticmethod - def cast_value(value: Any, type_: SegmentType): - # Cast Python's `bool` type to `int` when the runtime type requires - # an integer or number. - # - # This ensures compatibility with existing workflows that may use `bool` as - # `int`, since in Python's type system, `bool` is a subtype of `int`. - # - # This function exists solely to maintain compatibility with existing workflows. - # It should not be used to compromise the integrity of the runtime type system. - # No additional casting rules should be introduced to this function. 
-    @staticmethod
-    def cast_value(value: Any, type_: SegmentType):
-        # Cast Python's `bool` type to `int` when the runtime type requires
-        # an integer or number.
-        #
-        # This ensures compatibility with existing workflows that may use `bool` as
-        # `int`, since in Python's type system, `bool` is a subtype of `int`.
-        #
-        # This function exists solely to maintain compatibility with existing workflows.
-        # It should not be used to compromise the integrity of the runtime type system.
-        # No additional casting rules should be introduced to this function.
-
-        if type_ in (
-            SegmentType.INTEGER,
-            SegmentType.NUMBER,
-        ) and isinstance(value, bool):
-            return int(value)
-        if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value):
-            return [int(i) for i in value]
-        return value
-
-    def exposed_type(self) -> SegmentType:
-        """Returns the type exposed to the frontend.
-
-        The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
-        """
-        if self in (SegmentType.INTEGER, SegmentType.FLOAT):
-            return SegmentType.NUMBER
-        return self
-
-    def element_type(self) -> SegmentType | None:
-        """Return the element type of the current segment type, or `None` if the element type is undefined.
-
-        Raises:
-            ValueError: If the current segment type is not an array type.
-
-        Note:
-            For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined
-            by the runtime system. In such cases, this method will return `None`.
-        """
-        if not self.is_array_type():
-            raise ValueError(f"element_type is only supported by array type, got {self}")
-        return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
-
-    @staticmethod
-    def get_zero_value(t: SegmentType) -> Segment:
-        # Lazy import to avoid circular dependency between segment types and factory helpers.
-        from graphon.variables.factory import build_segment, build_segment_with_type
-
-        match t:
-            case (
-                SegmentType.ARRAY_OBJECT
-                | SegmentType.ARRAY_ANY
-                | SegmentType.ARRAY_STRING
-                | SegmentType.ARRAY_NUMBER
-                | SegmentType.ARRAY_BOOLEAN
-            ):
-                return build_segment_with_type(t, [])
-            case SegmentType.OBJECT:
-                return build_segment({})
-            case SegmentType.STRING:
-                return build_segment("")
-            case SegmentType.INTEGER:
-                return build_segment(0)
-            case SegmentType.FLOAT:
-                return build_segment(0.0)
-            case SegmentType.NUMBER:
-                return build_segment(0)
-            case SegmentType.BOOLEAN:
-                return build_segment(False)
-            case _:
-                raise ValueError(f"unsupported variable type: {t}")
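A few concrete data points for `cast_value` and `get_zero_value`, assuming the deleted `SegmentType` (and its factory helpers) were still importable:

```python
# cast_value narrows bool to int only for numeric target types, nothing else.
assert type(SegmentType.cast_value(True, SegmentType.INTEGER)) is int
assert SegmentType.cast_value([True, False], SegmentType.ARRAY_NUMBER) == [1, 0]
assert SegmentType.cast_value("x", SegmentType.INTEGER) == "x"  # non-bool passes through

# Zero values mirror the declared runtime type:
assert SegmentType.get_zero_value(SegmentType.STRING).value == ""
assert SegmentType.get_zero_value(SegmentType.ARRAY_NUMBER).value == []
```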
-
-
-_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
-    # ARRAY_ANY does not have a corresponding element type.
-    SegmentType.ARRAY_STRING: SegmentType.STRING,
-    SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
-    SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
-    SegmentType.ARRAY_FILE: SegmentType.FILE,
-    SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN,
-}
-
-_ARRAY_TYPES = frozenset(
-    list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
-    + [
-        SegmentType.ARRAY_ANY,
-    ]
-)
-
-_NUMERICAL_TYPES = frozenset(
-    [
-        SegmentType.NUMBER,
-        SegmentType.INTEGER,
-        SegmentType.FLOAT,
-    ]
-)
diff --git a/api/graphon/variables/utils.py b/api/graphon/variables/utils.py
deleted file mode 100644
index 8e738f8fd5..0000000000
--- a/api/graphon/variables/utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from collections.abc import Iterable, Sequence
-from typing import Any
-
-import orjson
-
-from .segment_group import SegmentGroup
-from .segments import ArrayFileSegment, FileSegment, Segment
-
-
-def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
-    selectors = [node_id, name]
-    if paths:
-        selectors.extend(paths)
-    return selectors
-
-
-def segment_orjson_default(o: Any):
-    """Default function for orjson serialization of Segment types"""
-    if isinstance(o, ArrayFileSegment):
-        return [v.model_dump() for v in o.value]
-    elif isinstance(o, FileSegment):
-        return o.value.model_dump()
-    elif isinstance(o, SegmentGroup):
-        return [segment_orjson_default(seg) for seg in o.value]
-    elif isinstance(o, Segment):
-        return o.value
-    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
-
-
-def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str:
-    """JSON dumps with segment support using orjson"""
-    option = orjson.OPT_NON_STR_KEYS
-    return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8")
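`dumps_with_segments` relies on orjson's `default=` hook for types it cannot natively encode, plus `OPT_NON_STR_KEYS` to stringify non-string dict keys; note that its `ensure_ascii` parameter appears to be unused, since orjson always emits UTF-8. A self-contained sketch of the same pattern with a hypothetical segment-like object:

```python
import orjson


def _default(o):
    # Fallback for types orjson cannot natively encode; mirrors
    # segment_orjson_default by unwrapping a hypothetical Segment-like object.
    if hasattr(o, "value"):
        return o.value
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")


class FakeSegment:  # hypothetical stand-in, not the deleted class
    value = [1, 2, 3]


payload = {"answer": FakeSegment(), 42: "non-str key"}
print(orjson.dumps(payload, default=_default, option=orjson.OPT_NON_STR_KEYS).decode("utf-8"))
# {"answer":[1,2,3],"42":"non-str key"}
```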
- """ - - id: str = Field( - default_factory=lambda: str(uuid4()), - description="Unique identity for variable.", - ) - name: str - description: str = Field(default="", description="Description of the variable.") - selector: Sequence[str] = Field(default_factory=list) - - -class StringVariable(StringSegment, VariableBase): - pass - - -class FloatVariable(FloatSegment, VariableBase): - pass - - -class IntegerVariable(IntegerSegment, VariableBase): - pass - - -class ObjectVariable(ObjectSegment, VariableBase): - pass - - -class ArrayVariable(ArraySegment, VariableBase): - pass - - -class ArrayAnyVariable(ArrayAnySegment, ArrayVariable): - pass - - -class ArrayStringVariable(ArrayStringSegment, ArrayVariable): - pass - - -class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable): - pass - - -class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable): - pass - - -class SecretVariable(StringVariable): - value_type: SegmentType = SegmentType.SECRET - - @property - def log(self) -> str: - return _obfuscated_token(self.value) - - -class NoneVariable(NoneSegment, VariableBase): - value_type: SegmentType = SegmentType.NONE - value: None = None - - -class FileVariable(FileSegment, VariableBase): - pass - - -class BooleanVariable(BooleanSegment, VariableBase): - pass - - -class ArrayFileVariable(ArrayFileSegment, ArrayVariable): - pass - - -class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable): - pass - - -class RAGPipelineVariable(BaseModel): - belong_to_node_id: str = Field(description="belong to which node id, shared means public") - type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") - label: str = Field(description="label") - description: str | None = Field(description="description", default="") - variable: str = Field(description="variable key", default="") - max_length: int | None = Field( - description="max length, applicable to text-input, paragraph, and file-list", default=0 - ) - default_value: Any = Field(description="default value", default="") - placeholder: str | None = Field(description="placeholder", default="") - unit: str | None = Field(description="unit, applicable to Number", default="") - tooltips: str | None = Field(description="helpful text", default="") - allowed_file_types: list[str] | None = Field( - description="image, document, audio, video, custom.", default_factory=list - ) - allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list) - allowed_file_upload_methods: list[str] | None = Field( - description="remote_url, local_file, tool_file.", default_factory=list - ) - required: bool = Field(description="optional, default false", default=False) - options: list[str] | None = Field(default_factory=list) - - -class RAGPipelineVariableInput(BaseModel): - variable: RAGPipelineVariable - value: Any - - -# The `Variable` type is used to enable serialization and deserialization with Pydantic. -# Use `VariableBase` for type hinting when serialization is not required. -# -# Note: -# - All variants in `Variable` must inherit from the `VariableBase` class. -# - The union must include all non-abstract subclasses of `VariableBase`. 
-
-
-# The `Variable` type is used to enable serialization and deserialization with Pydantic.
-# Use `VariableBase` for type hinting when serialization is not required.
-#
-# Note:
-# - All variants in `Variable` must inherit from the `VariableBase` class.
-# - The union must include all non-abstract subclasses of `VariableBase`.
-Variable: TypeAlias = Annotated[
-    (
-        Annotated[NoneVariable, Tag(SegmentType.NONE)]
-        | Annotated[StringVariable, Tag(SegmentType.STRING)]
-        | Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
-        | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
-        | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
-        | Annotated[FileVariable, Tag(SegmentType.FILE)]
-        | Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)]
-        | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
-        | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
-        | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
-        | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
-        | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
-        | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
-        | Annotated[SecretVariable, Tag(SegmentType.SECRET)]
-    ),
-    Discriminator(get_segment_discriminator),
-]
diff --git a/api/graphon/workflow_type_encoder.py b/api/graphon/workflow_type_encoder.py
deleted file mode 100644
index 7cdc83ebdb..0000000000
--- a/api/graphon/workflow_type_encoder.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from collections.abc import Mapping
-from decimal import Decimal
-from typing import Any, overload
-
-from pydantic import BaseModel
-
-from graphon.file.models import File
-from graphon.variables import Segment
-
-
-class WorkflowRuntimeTypeConverter:
-    @overload
-    def to_json_encodable(self, value: Mapping[str, Any]) -> Mapping[str, Any]: ...
-    @overload
-    def to_json_encodable(self, value: None) -> None: ...
-
-    def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
-        """Convert runtime values to JSON-serializable structures."""
-
-        result = self.value_to_json_encodable_recursive(value)
-        if isinstance(result, Mapping) or result is None:
-            return result
-        return {}
-
-    def value_to_json_encodable_recursive(self, value: Any):
-        if value is None:
-            return value
-        if isinstance(value, (bool, int, str, float)):
-            return value
-        if isinstance(value, Decimal):
-            # Convert Decimal to float for JSON serialization
-            return float(value)
-        if isinstance(value, Segment):
-            return self.value_to_json_encodable_recursive(value.value)
-        if isinstance(value, File):
-            return value.to_dict()
-        if isinstance(value, BaseModel):
-            return value.model_dump(mode="json")
-        if isinstance(value, dict):
-            res = {}
-            for k, v in value.items():
-                res[k] = self.value_to_json_encodable_recursive(v)
-            return res
-        if isinstance(value, list):
-            res_list = []
-            for item in value:
-                res_list.append(self.value_to_json_encodable_recursive(item))
-            return res_list
-        return value
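Before deletion, `WorkflowRuntimeTypeConverter` recursed through containers, unwrapped segments, and downcast `Decimal` to `float`. A hedged usage sketch, assuming the deleted class were still importable:

```python
from decimal import Decimal

converter = WorkflowRuntimeTypeConverter()

# Hypothetical nested payload: Decimals become floats, containers recurse,
# and plain scalars pass through untouched.
encoded = converter.to_json_encodable(
    {
        "price": Decimal("19.99"),
        "tags": [Decimal("1"), "a"],
    }
)
assert encoded == {"price": 19.99, "tags": [1.0, "a"]}
```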
diff --git a/api/libs/helper.py b/api/libs/helper.py
index b1815859a5..a7b3da77ff 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -16,14 +16,14 @@ from zoneinfo import available_timezones
 
 from flask import Response, stream_with_context
 from flask_restx import fields
+from graphon.file import helpers as file_helpers
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel
 from pydantic.functional_validators import AfterValidator
 
 from configs import dify_config
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
 from extensions.ext_redis import redis_client
-from graphon.file import helpers as file_helpers
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 if TYPE_CHECKING:
     from models import Account
diff --git a/api/models/enums.py b/api/models/enums.py
index cdec7b2f12..bf2e927f00 100644
--- a/api/models/enums.py
+++ b/api/models/enums.py
@@ -158,6 +158,15 @@ class FeedbackFromSource(StrEnum):
     ADMIN = "admin"
 
 
+class CustomizeTokenStrategy(StrEnum):
+    """Site token customization strategy"""
+
+    MUST = "must"
+    ALLOW = "allow"
+    NOT_ALLOW = "not_allow"
+    UUID = "uuid"
+
+
 class FeedbackRating(StrEnum):
     """MessageFeedback rating"""
 
@@ -314,6 +323,13 @@ class MessageChainType(StrEnum):
     SYSTEM = "system"
 
 
+class PromptType(StrEnum):
+    """Prompt configuration type"""
+
+    SIMPLE = "simple"
+    ADVANCED = "advanced"
+
+
 class ProviderQuotaType(StrEnum):
     PAID = "paid"
     """hosted paid quota"""
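The point of migrating these columns to `StrEnum` members (here and in the model changes below) is that members still compare equal to their raw strings, so existing string-based call sites keep working, while unknown values fail fast. A small illustration using a local mirror of the new `PromptType`:

```python
from enum import StrEnum


class PromptType(StrEnum):  # local mirror of the addition above
    SIMPLE = "simple"
    ADVANCED = "advanced"


# StrEnum members compare equal to their raw strings:
assert PromptType.SIMPLE == "simple"
assert PromptType("advanced") is PromptType.ADVANCED

# Unknown values raise instead of silently persisting:
try:
    PromptType("fancy")
except ValueError as e:
    print(e)  # 'fancy' is not a valid PromptType
```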
diff --git a/api/models/human_input.py b/api/models/human_input.py
index b4c7a634b6..79c5d62f6a 100644
--- a/api/models/human_input.py
+++ b/api/models/human_input.py
@@ -3,11 +3,11 @@ from enum import StrEnum
 from typing import Annotated, Literal, Self, final
 
 import sqlalchemy as sa
+from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from core.workflow.human_input_compat import DeliveryMethodType
-from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
 from libs.helper import generate_string
 
 from .base import Base, DefaultFieldsMixin
diff --git a/api/models/model.py b/api/models/model.py
index bcb142db56..066d2acdce 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -14,6 +14,9 @@ from uuid import uuid4
 import sqlalchemy as sa
 from flask import request
 from flask_login import UserMixin  # type: ignore[import-untyped]
+from graphon.enums import WorkflowExecutionStatus
+from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
+from graphon.file import helpers as file_helpers
 from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
 from sqlalchemy.orm import Mapped, Session, mapped_column
 from typing_extensions import TypedDict
@@ -22,9 +25,6 @@ from configs import dify_config
 from constants import DEFAULT_FILE_NUMBER_LIMITS
 from core.tools.signature import sign_tool_file
 from extensions.storage.storage_type import StorageType
-from graphon.enums import WorkflowExecutionStatus
-from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
-from graphon.file import helpers as file_helpers
 from libs.helper import generate_string  # type: ignore[import-not-found]
 from libs.uuid_utils import uuidv7
 from models.utils.file_input_compat import build_file_from_input_mapping
@@ -40,12 +40,14 @@ from .enums import (
     ConversationFromSource,
     ConversationStatus,
     CreatorUserRole,
+    CustomizeTokenStrategy,
     FeedbackFromSource,
     FeedbackRating,
     InvokeFrom,
     MessageChainType,
     MessageFileBelongsTo,
     MessageStatus,
+    PromptType,
     ProviderQuotaType,
     TagType,
 )
@@ -649,8 +651,11 @@ class AppModelConfig(TypeBase):
     agent_mode: Mapped[str | None] = mapped_column(LongText, default=None)
     sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None)
     retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None)
-    prompt_type: Mapped[str] = mapped_column(
-        String(255), nullable=False, server_default=sa.text("'simple'"), default="simple"
+    prompt_type: Mapped[PromptType] = mapped_column(
+        EnumText(PromptType, length=255),
+        nullable=False,
+        server_default=sa.text("'simple'"),
+        default=PromptType.SIMPLE,
     )
     chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
     completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
@@ -802,7 +807,7 @@ class AppModelConfig(TypeBase):
             "dataset_query_variable": self.dataset_query_variable,
             "pre_prompt": self.pre_prompt,
             "agent_mode": self.agent_mode_dict,
-            "prompt_type": self.prompt_type,
+            "prompt_type": self.prompt_type.value if isinstance(self.prompt_type, PromptType) else self.prompt_type,
             "chat_prompt_config": self.chat_prompt_config_dict,
             "completion_prompt_config": self.completion_prompt_config_dict,
             "dataset_configs": self.dataset_configs_dict,
@@ -846,7 +851,7 @@ class AppModelConfig(TypeBase):
         self.retriever_resource = (
             json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
         )
-        self.prompt_type = model_config.get("prompt_type", "simple")
+        self.prompt_type = PromptType(model_config.get("prompt_type", "simple"))
         self.chat_prompt_config = (
             json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None
        )
@@ -2084,7 +2089,9 @@ class Site(Base):
     use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
     _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", LongText, default="")
     customize_domain = mapped_column(String(255))
-    customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False)
+    customize_token_strategy: Mapped[CustomizeTokenStrategy] = mapped_column(
+        EnumText(CustomizeTokenStrategy, length=255), nullable=False
+    )
     prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
     status: Mapped[AppStatus] = mapped_column(
         EnumText(AppStatus, length=255), nullable=False, server_default=sa.text("'normal'"), default=AppStatus.NORMAL
diff --git a/api/models/tools.py b/api/models/tools.py
index 63b27b9413..d8731fb8a8 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -11,6 +11,7 @@ from deprecated import deprecated
 from sqlalchemy import ForeignKey, String, func, select
 from sqlalchemy.orm import Mapped, mapped_column
 
+from core.plugin.entities.plugin_daemon import CredentialType
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
@@ -109,8 +110,11 @@ class BuiltinToolProvider(TypeBase):
     )
     is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
     # credential type, e.g., "api-key", "oauth2"
-    credential_type: Mapped[str] = mapped_column(
-        String(32), nullable=False, server_default=sa.text("'api-key'"), default="api-key"
+    credential_type: Mapped[CredentialType] = mapped_column(
+        EnumText(CredentialType, length=32),
+        nullable=False,
+        server_default=sa.text("'api-key'"),
+        default=CredentialType.API_KEY,
     )
     expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1)
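`EnumText` is the project's column type for storing `StrEnum` values; its implementation is not shown in this diff, but the pattern it follows can be sketched as a SQLAlchemy `TypeDecorator` that validates on write and rehydrates the enum on read (a hedged stand-in, not the real class):

```python
from enum import StrEnum

from sqlalchemy import String, TypeDecorator


class EnumTextSketch(TypeDecorator):
    """Hypothetical stand-in for the project's EnumText column type."""

    impl = String
    cache_ok = True

    def __init__(self, enum_cls: type[StrEnum], length: int = 255):
        super().__init__(length)
        self._enum_cls = enum_cls

    def process_bind_param(self, value, dialect):
        # Validate on write: unknown strings raise ValueError instead of persisting.
        return None if value is None else self._enum_cls(value).value

    def process_result_value(self, value, dialect):
        # Rehydrate on read, so ORM attributes are enum members, not bare strings.
        return None if value is None else self._enum_cls(value)
```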
diff --git a/api/models/trigger.py b/api/models/trigger.py
index 627b854060..5233a6e271 100644
--- a/api/models/trigger.py
+++ b/api/models/trigger.py
@@ -102,7 +102,9 @@ class TriggerSubscription(TypeBase):
     credentials: Mapped[TriggerCredentials] = mapped_column(
         sa.JSON, nullable=False, comment="Subscription credentials JSON"
     )
-    credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
+    credential_type: Mapped[CredentialType] = mapped_column(
+        EnumText(CredentialType, length=50), nullable=False, comment="oauth or api_key"
+    )
     credential_expires_at: Mapped[int] = mapped_column(
         Integer, default=-1, comment="OAuth token expiration timestamp, -1 for never"
     )
@@ -144,7 +146,7 @@ class TriggerSubscription(TypeBase):
             endpoint=generate_plugin_trigger_endpoint_url(self.endpoint_id),
             parameters=self.parameters,
             properties=self.properties,
-            credential_type=CredentialType(self.credential_type),
+            credential_type=self.credential_type,
             credentials=self.credentials,
             workflows_in_use=-1,
         )
diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py
index dee1cc507a..f71583c1cd 100644
--- a/api/models/utils/file_input_compat.py
+++ b/api/models/utils/file_input_compat.py
@@ -4,9 +4,10 @@ from collections.abc import Callable, Mapping
 from functools import lru_cache
 from typing import Any
 
-from core.workflow.file_reference import parse_file_reference
 from graphon.file import File, FileTransferMethod
 
+from core.workflow.file_reference import parse_file_reference
+
 
 @lru_cache(maxsize=1)
 def _get_file_access_controller():
diff --git a/api/models/workflow.py b/api/models/workflow.py
index d15bf71d39..f8868cb73c 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -8,6 +8,19 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
 from uuid import uuid4
 
 import sqlalchemy as sa
+from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
+from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
+from graphon.enums import (
+    BuiltinNodeTypes,
+    NodeType,
+    WorkflowExecutionStatus,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
+from graphon.file import File
+from graphon.file.constants import maybe_file_object
+from graphon.variables import utils as variable_utils
+from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
 from sqlalchemy import (
     DateTime,
     Index,
@@ -31,13 +44,6 @@ from core.workflow.variable_prefixes import (
 )
 from extensions.ext_storage import Storage
 from factories.variable_factory import TypeMismatchError, build_segment_with_type
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
-from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
-from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
-from graphon.file.constants import maybe_file_object
-from graphon.file.models import File
-from graphon.variables import utils as variable_utils
-from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
 from libs.datetime_utils import naive_utc_now
 from libs.uuid_utils import uuidv7
 
@@ -47,10 +53,11 @@ if TYPE_CHECKING:
 
 from .model import AppMode, UploadFile
 
+from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase
+
 from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
 from core.helper import encrypter
 from factories import variable_factory
-from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase
 from libs import helper
 
 from .account import Account
@@ -941,7 +948,7 @@ class WorkflowNodeExecutionModel(Base):  # This model is expected to have `offlo
     inputs: Mapped[str | None] = mapped_column(LongText)
     process_data: Mapped[str | None] = mapped_column(LongText)
     outputs: Mapped[str | None] = mapped_column(LongText)
-    status: Mapped[str] = mapped_column(String(255))
+    status: Mapped[WorkflowNodeExecutionStatus] = mapped_column(EnumText(WorkflowNodeExecutionStatus, length=255))
     error: Mapped[str | None] = mapped_column(LongText)
     elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
     execution_metadata: Mapped[str | None] = mapped_column(LongText)
@@ -1460,8 +1467,6 @@ class WorkflowDraftVariable(Base):
     # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than
     # 80 chars.
-    #
-    # ref: api/graphon/entities/variable_pool.py:18
     name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
     description: Mapped[str] = mapped_column(
         sa.String(255),
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 0398376ee2..f737d0699f 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -8,7 +8,7 @@ dependencies = [
     "arize-phoenix-otel~=0.15.0",
     "azure-identity==1.25.3",
     "beautifulsoup4==4.14.3",
-    "boto3==1.42.73",
+    "boto3==1.42.78",
     "bs4~=0.0.1",
     "cachetools~=5.3.0",
     "celery~=5.6.2",
@@ -28,11 +28,11 @@ dependencies = [
     "google-auth-httplib2==0.3.0",
     "google-cloud-aiplatform>=1.123.0",
     "googleapis-common-protos>=1.65.0",
-    "gunicorn~=25.1.0",
+    "graphon>=0.1.2",
+    "gunicorn~=25.3.0",
     "httpx[socks]~=0.28.0",
     "jieba==0.42.1",
     "json-repair>=0.55.1",
-    "jsonschema>=4.25.1",
     "langfuse~=2.51.3",
     "langsmith~=0.7.16",
     "markdown~=3.10.2",
@@ -63,7 +63,6 @@ dependencies = [
     "psycopg2-binary~=2.9.6",
     "pycryptodome==3.23.0",
     "pydantic~=2.12.5",
-    "pydantic-extra-types~=2.11.0",
     "pydantic-settings~=2.13.1",
     "pyjwt~=2.12.0",
     "pypdfium2==5.6.0",
@@ -71,7 +70,7 @@ dependencies = [
     "python-dotenv==1.2.2",
     "pyyaml~=6.0.1",
     "readabilipy~=0.3.0",
-    "redis[hiredis]~=7.3.0",
+    "redis[hiredis]~=7.4.0",
     "resend~=2.26.0",
     "sentry-sdk[flask]~=2.55.0",
     "sqlalchemy~=2.0.29",
@@ -81,7 +80,6 @@ dependencies = [
     "unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
     "pypandoc~=1.13",
     "yarl~=1.23.0",
-    "webvtt-py~=0.5.1",
     "sseclient-py~=1.9.0",
     "httpx-sse~=0.4.0",
     "sendgrid~=6.12.3",
@@ -130,7 +128,6 @@ dev = [
     "types-defusedxml~=0.7.0",
     "types-deprecated~=1.3.1",
    "types-docutils~=0.22.3",
-    "types-jsonschema~=4.26.0",
     "types-flask-cors~=6.0.0",
     "types-flask-migrate~=4.1.0",
     "types-gevent~=25.9.0",
@@ -150,7 +147,7 @@ dev = [
     "types-python-dateutil~=2.9.0",
     "types-pywin32~=311.0.0",
     "types-pyyaml~=6.0.12",
-    "types-regex~=2026.2.28",
+    "types-regex~=2026.3.32",
     "types-shapely~=2.1.0",
     "types-simplejson>=3.20.0",
     "types-six>=1.17.0",
@@ -206,7 +203,7 @@ vdb = [
     "alibabacloud_gpdb20160503~=5.1.0",
     "alibabacloud_tea_openapi~=0.4.3",
     "chromadb==0.5.20",
-    "clickhouse-connect~=0.14.1",
+    "clickhouse-connect~=0.15.0",
     "clickzetta-connector-python>=0.8.102",
     "couchbase~=4.5.0",
     "elasticsearch==8.14.0",
@@ -219,38 +216,18 @@
     "pyobvector~=0.2.17",
     "qdrant-client==1.9.0",
     "intersystems-irispython>=5.1.0",
-    "tablestore==6.4.1",
-    "tcvectordb~=2.0.0",
+    "tablestore==6.4.2",
+    "tcvectordb~=2.1.0",
     "tidb-vector==0.0.15",
     "upstash-vector==0.8.0",
     "volcengine-compat~=1.0.0",
     "weaviate-client==4.20.4",
-    "xinference-client~=2.3.1",
+    "xinference-client~=2.4.0",
     "mo-vector~=0.1.13",
     "mysql-connector-python>=9.3.0",
     "holo-search-sdk>=0.4.1",
 ]
 
-[tool.mypy]
-
-[[tool.mypy.overrides]]
-# targeted ignores for current type-check errors
-# TODO(QuantumGhost): suppress type errors in HITL related code.
-# fix the type error later
-module = [
-    "configs.middleware.cache.redis_pubsub_config",
-    "extensions.ext_redis",
-    "tasks.workflow_execution_tasks",
-    "graphon.nodes.base.node",
-    "services.human_input_delivery_test_service",
-    "core.app.apps.advanced_chat.app_generator",
-    "controllers.console.human_input_form",
-    "controllers.console.app.workflow_run",
-    "repositories.sqlalchemy_api_workflow_node_execution_repository",
-    "extensions.logstore.repositories.logstore_api_workflow_run_repository",
-]
-ignore_errors = true
-
 [tool.pyrefly]
 project-includes = ["."]
 project-excludes = [".venv", "migrations/"]
diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt
index dc0adbf50d..43f604c2de 100644
--- a/api/pyrefly-local-excludes.txt
+++ b/api/pyrefly-local-excludes.txt
@@ -109,34 +109,16 @@ core/trigger/debug/event_selectors.py
 core/trigger/entities/entities.py
 core/trigger/provider.py
 core/workflow/workflow_entry.py
-graphon/entities/workflow_execution.py
-graphon/file/file_manager.py
-graphon/graph_engine/error_handler.py
-graphon/graph_engine/layers/execution_limits.py
-graphon/nodes/agent/agent_node.py
-graphon/nodes/base/node.py
-graphon/nodes/code/code_node.py
-graphon/nodes/datasource/datasource_node.py
-graphon/nodes/document_extractor/node.py
-graphon/nodes/human_input/human_input_node.py
-graphon/nodes/if_else/if_else_node.py
-graphon/nodes/iteration/iteration_node.py
-graphon/nodes/knowledge_index/knowledge_index_node.py
+enterprise/telemetry/contracts.py
+enterprise/telemetry/draft_trace.py
+enterprise/telemetry/enterprise_trace.py
+enterprise/telemetry/entities/__init__.py
+enterprise/telemetry/event_handlers.py
+enterprise/telemetry/exporter.py
+enterprise/telemetry/id_generator.py
+enterprise/telemetry/metric_handler.py
+enterprise/telemetry/telemetry_log.py
 core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
-graphon/nodes/list_operator/node.py
-graphon/nodes/llm/node.py
-graphon/nodes/loop/loop_node.py
-graphon/nodes/parameter_extractor/parameter_extractor_node.py
-graphon/nodes/question_classifier/question_classifier_node.py
-graphon/nodes/start/start_node.py
-graphon/nodes/template_transform/template_transform_node.py
-graphon/nodes/tool/tool_node.py
-graphon/nodes/trigger_plugin/trigger_event_node.py
-graphon/nodes/trigger_schedule/trigger_schedule_node.py
-graphon/nodes/trigger_webhook/node.py
-graphon/nodes/variable_aggregator/variable_aggregator_node.py
-graphon/nodes/variable_assigner/v1/node.py
-graphon/nodes/variable_assigner/v2/node.py
 extensions/logstore/repositories/logstore_api_workflow_run_repository.py
 extensions/otel/instrumentation.py
 extensions/otel/runtime.py
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
index ffc17e92cf..1a2a539c80 100644
--- a/api/repositories/api_workflow_run_repository.py
+++ b/api/repositories/api_workflow_run_repository.py
@@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence
 from datetime import datetime
 from typing import Protocol
 
+from graphon.entities.pause_reason import PauseReason
+from graphon.enums import WorkflowType
 from sqlalchemy.orm import Session
 
 from core.repositories.factory import WorkflowExecutionRepository
-from graphon.entities.pause_reason import PauseReason
-from graphon.enums import WorkflowType
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.enums import WorkflowRunTriggeredFrom
 from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
index 44735eb769..d5c6a203b1 100644
--- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
@@ -10,11 +10,11 @@ from collections.abc import Sequence
 from datetime import datetime
 from typing import Protocol, cast
 
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from sqlalchemy import asc, delete, desc, func, select
 from sqlalchemy.engine import CursorResult
 from sqlalchemy.orm import Session, sessionmaker
 
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
 from repositories.api_workflow_node_execution_repository import (
     DifyAPIWorkflowNodeExecutionRepository,
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
index 5bb0c74ada..413936b542 100644
--- a/api/repositories/sqlalchemy_api_workflow_run_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -28,15 +28,15 @@ from decimal import Decimal
 from typing import Any, cast
 
 import sqlalchemy as sa
+from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
+from graphon.enums import WorkflowExecutionStatus, WorkflowType
+from graphon.nodes.human_input.entities import FormDefinition
 from pydantic import ValidationError
 from sqlalchemy import and_, delete, func, null, or_, select, tuple_
 from sqlalchemy.engine import CursorResult
 from sqlalchemy.orm import Session, selectinload, sessionmaker
 
 from extensions.ext_storage import storage
-from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
-from graphon.enums import WorkflowExecutionStatus, WorkflowType
-from graphon.nodes.human_input.entities import FormDefinition
 from libs.datetime_utils import naive_utc_now
 from libs.helper import convert_datetime_to_date
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py
index 67f8795d3f..feba5f7eb6 100644
--- a/api/repositories/sqlalchemy_execution_extra_content_repository.py
+++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py
@@ -7,6 +7,9 @@ from collections import defaultdict
 from collections.abc import Sequence
 from typing import Any
 
+from graphon.nodes.human_input.entities import FormDefinition
+from graphon.nodes.human_input.enums import HumanInputFormStatus
+from graphon.nodes.human_input.human_input_node import HumanInputNode
 from sqlalchemy import select
 from sqlalchemy.orm import Session, selectinload, sessionmaker
 
@@ -18,9 +21,6 @@ from core.entities.execution_extra_content import (
 )
 from core.entities.execution_extra_content import (
     HumanInputContent as HumanInputContentDomainModel,
 )
-from graphon.nodes.human_input.entities import FormDefinition
-from graphon.nodes.human_input.enums import HumanInputFormStatus
-from graphon.nodes.human_input.human_input_node import HumanInputNode
 from models.execution_extra_content import (
     ExecutionExtraContent as ExecutionExtraContentModel,
 )
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 643a2a2a84..dd73e10374 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -11,6 +11,12 @@ from uuid import uuid4
 import yaml
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
+from graphon.enums import BuiltinNodeTypes
+from graphon.model_runtime.utils.encoders import jsonable_encoder
+from graphon.nodes.llm.entities import LLMNodeData
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
+from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
+from graphon.nodes.tool.entities import ToolNodeData
 from packaging import version
 from packaging.version import parse as parse_version
 from pydantic import BaseModel, Field
@@ -30,12 +36,6 @@ from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerSc
 from events.app_event import app_model_config_was_updated, app_was_created
 from extensions.ext_redis import redis_client
 from factories import variable_factory
-from graphon.enums import BuiltinNodeTypes
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.nodes.llm.entities import LLMNodeData
-from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
-from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
-from graphon.nodes.tool.entities import ToolNodeData
 from libs.datetime_utils import naive_utc_now
 from models import Account, App, AppMode
 from models.model import AppModelConfig, AppModelConfigDict, IconType
diff --git a/api/services/app_service.py b/api/services/app_service.py
index a9ec357455..e9aeb6c43d 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -4,6 +4,8 @@ from typing import Any, TypedDict, cast
 
 import sqlalchemy as sa
 from flask_sqlalchemy.pagination import Pagination
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
 from configs import dify_config
 from constants.model_template import default_app_templates
@@ -12,10 +14,8 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.configuration import ToolParameterConfigurationManager
-from events.app_event import app_was_created
+from events.app_event import app_was_created, app_was_deleted, app_was_updated
 from extensions.ext_database import db
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_user
 from models import Account
@@ -281,6 +281,8 @@ class AppService:
         app.updated_at = naive_utc_now()
         db.session.commit()
 
+        app_was_updated.send(app)
+
         return app
 
     def update_app_name(self, app: App, name: str) -> App:
@@ -296,6 +298,8 @@ class AppService:
         app.updated_at = naive_utc_now()
         db.session.commit()
 
+        app_was_updated.send(app)
+
         return app
 
     def update_app_icon(self, app: App, icon: str, icon_background: str) -> App:
@@ -313,6 +317,8 @@ class AppService:
         app.updated_at = naive_utc_now()
         db.session.commit()
 
+        app_was_updated.send(app)
+
         return app
 
     def update_app_site_status(self, app: App, enable_site: bool) -> App:
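The new `app_was_updated.send(app)` calls follow the blinker signal pattern already used by `app_was_created`: senders publish by name, and receivers subscribe without the service knowing about them. A self-contained sketch (the signal name and `FakeApp` are assumptions for illustration):

```python
from blinker import signal

# Hypothetical mirror of events/app_event.py; the exact name is an assumption.
app_was_updated = signal("app-was-updated")


@app_was_updated.connect
def _on_app_updated(sender, **kwargs):
    # `sender` is whatever was passed to .send(), here the App instance.
    print(f"app updated: {sender.id}")


class FakeApp:
    id = "app-123"


app_was_updated.send(FakeApp())  # triggers _on_app_updated
```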
@@ -330,6 +336,8 @@ class AppService:
         app.updated_at = naive_utc_now()
         db.session.commit()
 
+        app_was_updated.send(app)
+
         return app
 
     def update_app_api_status(self, app: App, enable_api: bool) -> App:
@@ -348,6 +356,8 @@ class AppService:
         app.updated_at = naive_utc_now()
         db.session.commit()
 
+        app_was_updated.send(app)
+
         return app
 
     def delete_app(self, app: App):
@@ -355,6 +365,8 @@ class AppService:
         Delete app
         :param app: App instance
         """
+        app_was_deleted.send(app)
+
         db.session.delete(app)
         db.session.commit()
 
diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py
index 6e9d6b1c73..0842e9d3e7 100644
--- a/api/services/app_task_service.py
+++ b/api/services/app_task_service.py
@@ -5,10 +5,11 @@ like stopping tasks, handling both legacy Redis flag mechanism
 and new GraphEngine command channel mechanism.
 """
 
+from graphon.graph_engine.manager import GraphEngineManager
+
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_redis import redis_client
-from graphon.graph_engine.manager import GraphEngineManager
 from models.model import AppMode
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index 9e743bf7b1..90e72d5f34 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -5,12 +5,12 @@ from collections.abc import Generator
 from typing import cast
 
 from flask import Response, stream_with_context
+from graphon.model_runtime.entities.model_entities import ModelType
 from werkzeug.datastructures import FileStorage
 
 from constants import AUDIO_EXTENSIONS
 from core.model_manager import ModelManager
 from extensions.ext_database import db
-from graphon.model_runtime.entities.model_entities import ModelType
 from models.enums import MessageStatus
 from models.model import App, AppMode, Message
 from services.errors.audio import (
diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py
index 56aaf407ee..3282dcfb11 100644
--- a/api/services/auth/api_key_auth_service.py
+++ b/api/services/auth/api_key_auth_service.py
@@ -35,15 +35,13 @@ class ApiKeyAuthService:
 
     @staticmethod
     def get_auth_credentials(tenant_id: str, category: str, provider: str):
-        data_source_api_key_bindings = (
-            db.session.query(DataSourceApiKeyAuthBinding)
-            .where(
+        data_source_api_key_bindings = db.session.scalar(
+            select(DataSourceApiKeyAuthBinding).where(
                 DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
                 DataSourceApiKeyAuthBinding.category == category,
                 DataSourceApiKeyAuthBinding.provider == provider,
                 DataSourceApiKeyAuthBinding.disabled.is_(False),
             )
-            .first()
         )
         if not data_source_api_key_bindings:
             return None
@@ -54,10 +52,11 @@ class ApiKeyAuthService:
 
     @staticmethod
     def delete_provider_auth(tenant_id: str, binding_id: str):
-        data_source_api_key_binding = (
-            db.session.query(DataSourceApiKeyAuthBinding)
-            .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
-            .first()
+        data_source_api_key_binding = db.session.scalar(
+            select(DataSourceApiKeyAuthBinding).where(
+                DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+                DataSourceApiKeyAuthBinding.id == binding_id,
+            )
         )
         if data_source_api_key_binding:
             db.session.delete(data_source_api_key_binding)
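The `ApiKeyAuthService` change above is the standard SQLAlchemy 1.x-to-2.0 migration: `Query(...).first()` becomes a `select()` passed to `Session.scalar()`, which returns the first mapped entity or `None`. A runnable miniature with a stand-in model:

```python
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Binding(Base):  # hypothetical stand-in for DataSourceApiKeyAuthBinding
    __tablename__ = "binding"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str]
    disabled: Mapped[bool] = mapped_column(default=False)


engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Binding(tenant_id="t1"))
    session.commit()

    # 2.0 style: Session.scalar() returns the first entity or None,
    # replacing Query(...).first() from the legacy API.
    stmt = select(Binding).where(Binding.tenant_id == "t1", Binding.disabled.is_(False))
    binding = session.scalar(stmt)
    assert binding is not None and binding.tenant_id == "t1"
```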
diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py
index c6b32b373e..1c128524ad 100644
--- a/api/services/clear_free_plan_tenant_expired_logs.py
+++ b/api/services/clear_free_plan_tenant_expired_logs.py
@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
 
 import click
 from flask import Flask, current_app
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
@@ -13,7 +14,6 @@ from configs import dify_config
 from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from extensions.ext_storage import storage
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from models.account import Tenant
 from models.model import (
     App,
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index 545c5048d5..ba1e7bb826 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -3,6 +3,7 @@ import logging
 from collections.abc import Callable, Sequence
 from typing import Any, Union
 
+from graphon.variables.types import SegmentType
 from sqlalchemy import asc, desc, func, or_, select
 from sqlalchemy.orm import Session
 
@@ -12,7 +13,6 @@ from core.db.session_factory import session_factory
 from core.llm_generator.llm_generator import LLMGenerator
 from extensions.ext_database import db
 from factories import variable_factory
-from graphon.variables.types import SegmentType
 from libs.datetime_utils import naive_utc_now
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models import Account, ConversationVariable
diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py
index 287d513f48..95a8951951 100644
--- a/api/services/conversation_variable_updater.py
+++ b/api/services/conversation_variable_updater.py
@@ -1,7 +1,7 @@
+from graphon.variables.variables import VariableBase
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
-from graphon.variables.variables import VariableBase
 from models import ConversationVariable
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 3e2342b1a7..83363125c3 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -10,6 +10,9 @@ from collections.abc import Sequence
 from typing import Any, Literal, cast
 
 import sqlalchemy as sa
+from graphon.file import helpers as file_helpers
+from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
+from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from redis.exceptions import LockNotOwnedError
 from sqlalchemy import exists, func, select
 from sqlalchemy.orm import Session
@@ -28,9 +31,6 @@ from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from graphon.file import helpers as file_helpers
-from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_user
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 2b7bebb01e..06f83a18f7 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -3,6 +3,7 @@ import time
 from collections.abc import Mapping
 from typing import Any
 
+from graphon.model_runtime.entities.provider_entities import FormType
 from sqlalchemy.orm import Session
 
 from configs import dify_config
@@ -16,7 +17,6 @@ from core.plugin.impl.oauth import OAuthHandler
 from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from graphon.model_runtime.entities.provider_entities import FormType
 from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
 from models.provider_ids import DatasourceProviderID
 from services.plugin.plugin_service import PluginService
diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py
index 6679c08ebd..a944ef6acd 100644
--- a/api/services/entities/model_provider_entities.py
+++ b/api/services/entities/model_provider_entities.py
@@ -1,6 +1,15 @@
 from collections.abc import Sequence
 from enum import StrEnum
 
+from graphon.model_runtime.entities.common_entities import I18nObject
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.entities.provider_entities import (
+    ConfigurateMethod,
+    ModelCredentialSchema,
+    ProviderCredentialSchema,
+    ProviderHelpEntity,
+    SimpleProviderEntity,
+)
 from pydantic import BaseModel, ConfigDict, model_validator
 
 from configs import dify_config
@@ -15,15 +24,6 @@ from core.entities.provider_entities import (
     QuotaConfiguration,
     UnaddedModelConfiguration,
 )
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.provider_entities import (
-    ConfigurateMethod,
-    ModelCredentialSchema,
-    ProviderCredentialSchema,
-    ProviderHelpEntity,
-    SimpleProviderEntity,
-)
 from models.provider import ProviderType
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index d2fa98f5e2..64852c222f 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -4,13 +4,13 @@ from typing import Any, Union, cast
 from urllib.parse import urlparse
 
 import httpx
+from graphon.nodes.http_request.exc import InvalidHttpMethodError
 from sqlalchemy import select
 
 from constants import HIDDEN_VALUE
 from core.helper import ssrf_proxy
 from core.rag.entities.metadata_entities import MetadataCondition
 from extensions.ext_database import db
-from graphon.nodes.http_request.exc import InvalidHttpMethodError
 from libs.datetime_utils import naive_utc_now
 from models.dataset import (
     Dataset,
diff --git a/api/services/file_service.py b/api/services/file_service.py
index c11f018f52..50a326d813 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -8,6 +8,7 @@ from tempfile import NamedTemporaryFile
 from typing import Literal, Union
 from zipfile import ZIP_DEFLATED, ZipFile
 
+from graphon.file import helpers as file_helpers
 from sqlalchemy import Engine, select
 from sqlalchemy.orm import Session, sessionmaker
 from werkzeug.exceptions import NotFound
@@ -23,7 +24,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from extensions.storage.storage_type import StorageType
-from graphon.file import helpers as file_helpers
 from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_tenant_id
Account diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index d490ad1561..82e0b0f8b1 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,6 +3,8 @@ import logging import time from typing import Any +from graphon.model_runtime.entities import LLMMode + from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.query_type import QueryType @@ -10,7 +12,6 @@ from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db -from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery from models.enums import CreatorUserRole, DatasetQuerySource diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 861d952c93..77576fa4c0 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol +from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker @@ -17,7 +18,6 @@ from core.workflow.human_input_compat import ( ) from extensions.ext_database import db from extensions.ext_mail import mail -from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService @@ -220,7 +220,7 @@ class EmailDeliveryTestHandler: stmt = stmt.where(Account.id.in_(unique_ids)) with self._session_factory() as session: - rows = session.execute(stmt).all() + rows = session.execute(stmt).tuples().all() return dict(rows) @staticmethod diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 76598d31ac..02a6620fc7 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,6 +3,12 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -11,12 +17,6 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from graphon.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/message_service.py b/api/services/message_service.py index 0c4a334b47..e5389ef659 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -2,6 +2,7 @@ import json from collections.abc import Sequence from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType from 
sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -13,7 +14,6 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 469357d6e0..91cca5cb6d 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -3,6 +3,12 @@ import logging from json import JSONDecodeError from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -13,12 +19,6 @@ from core.model_manager import LBModelManager from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index e634f90603..3f37c9b176 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,10 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule + from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager -from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 8a28537528..bcf5973d7b 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -9,6 +9,15 @@ from typing import Any, Union, cast from uuid import uuid4 from flask_login import current_user +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from 
graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -46,20 +55,8 @@ from core.workflow.system_variables import ( ) from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType -from graphon.errors import WorkflowNodeRunFailedError -from graphon.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.graph_events.base import GraphNodeEventBase -from graphon.node_events.base import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.runtime import VariablePool -from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.dataset import ( # type: ignore @@ -577,6 +574,13 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + if workflow_node_execution_db_model is not None: + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=account.id, + ) return workflow_node_execution_db_model def run_datasource_workflow_node( @@ -1339,6 +1343,12 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=current_user.id, + ) return workflow_node_execution_db_model def get_recommended_plugins(self, type: str) -> dict: diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 1b8207cc31..04156713f4 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -14,6 +14,12 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -28,12 +34,6 @@ from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_redis import redis_client from factories import variable_factory -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import 
LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index c91f621ffb..2c1f99a3bc 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -27,13 +27,13 @@ from dataclasses import dataclass, field from typing import Any import click +from graphon.enums import WorkflowType from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db -from graphon.enums import WorkflowType from libs.archive_storage import ( ArchiveStorage, ArchiveStorageNotConfiguredError, diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 4334412c8b..12053377e2 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -6,6 +6,8 @@ import uuid from datetime import UTC, datetime from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.db.session_factory import session_factory @@ -15,8 +17,6 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 9190a67249..2a56bc0c71 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from typing import Any, cast +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select from typing_extensions import TypedDict @@ -21,7 +22,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6797a67dde..8e3c36e099 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -275,7 +275,7 @@ class BuiltinToolManageService: user_id=user_id, provider=provider, 
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), - credential_type=api_type.value, + credential_type=api_type, name=name, expires_at=expires_at if expires_at is not None else -1, ) @@ -314,7 +314,7 @@ class BuiltinToolManageService: .filter_by( tenant_id=tenant_id, provider=provider, - credential_type=credential_type.value, + credential_type=credential_type, ) .order_by(BuiltinToolProvider.created_at.desc()) .all() diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b6e5367023..b276146066 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -423,7 +423,7 @@ class ToolTransformService: id=provider.id, name=provider.name, provider=provider.provider, - credential_type=CredentialType.of(provider.credential_type), + credential_type=provider.credential_type, is_default=provider.is_default, credentials=credentials, ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 931ca5021a..fb6b5bea24 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -13,7 +14,6 @@ from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurati from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from graphon.model_runtime.utils.encoders import jsonable_encoder from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index a827222c1d..25e80770b8 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.entities.graph_config import NodeConfigDict from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,7 +14,6 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError -from graphon.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 688993c798..008d8bdb8a 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -198,7 +198,7 @@ class TriggerProviderService: credentials=dict(credential_encrypter.encrypt(dict(credentials))) if credential_encrypter else {}, - credential_type=credential_type.value, + credential_type=credential_type, credential_expires_at=credential_expires_at, expires_at=expires_at, ) diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index dca00a466b..d72c041609 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,6 +5,7 @@ from 
collections.abc import Mapping from typing import Any from flask import Request, Response +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -20,7 +21,6 @@ from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_ from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client -from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 5d9be84c06..c03275497d 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -7,6 +7,9 @@ from typing import Any import orjson from flask import request +from graphon.entities.graph_config import NodeConfigDict +from graphon.file import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -28,9 +31,6 @@ from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory -from graphon.entities.graph_config import NodeConfigDict -from graphon.file.models import FileTransferMethod -from graphon.variables.types import ArrayValidation, SegmentType from models.enums import AppTriggerStatus, AppTriggerType from models.model import App from models.trigger import AppTrigger, WorkflowWebhookTrigger diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index d0a4317065..62916cc2c9 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,8 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload -from configs import dify_config -from graphon.file.models import File +from graphon.file import File from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.variables.segments import ( ArrayFileSegment, @@ -22,6 +21,8 @@ from graphon.variables.segments import ( ) from graphon.variables.utils import dumps_with_segments +from configs import dify_config + _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 5fd310b689..3f78b823a6 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,5 +1,7 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -9,7 +11,6 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument 
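(Editorial note, not part of the patch: the dominant pattern in the service and task hunks above is that graphon.* imports move out of the first-party import block, where they sat alongside configs, core, extensions, and models, and into the third-party block next to sqlalchemy; graphon now sorts as an external package. A minimal sketch of the resulting grouping, assuming isort/ruff-style section ordering and reusing names from the conversation_service.py hunk earlier in this set:

# Illustrative sketch only -- the import order these hunks converge on.
# Standard library first.
import logging

# Third-party packages: graphon now belongs here, together with sqlalchemy.
from graphon.variables.types import SegmentType
from sqlalchemy import select

# First-party application packages come last.
from extensions.ext_database import db
from models import ConversationVariable

End of editorial note; the diff resumes below.)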
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1f3993505c..31367f72fa 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,6 +1,11 @@ import json from typing import Any +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from typing_extensions import TypedDict from core.app.app_config.entities import ( @@ -19,11 +24,6 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.app_event import app_was_created from extensions.ext_database import db -from graphon.file.models import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index fa26f507ee..bf178e8a44 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -3,11 +3,11 @@ import uuid from datetime import datetime from typing import Any +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from typing_extensions import TypedDict -from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 0b5c89e574..98e338a2d4 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -6,6 +6,19 @@ from concurrent.futures import ThreadPoolExecutor from enum import StrEnum from typing import Any, ClassVar +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from sqlalchemy import Engine, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -26,19 +39,6 @@ from core.workflow.variable_prefixes import ( from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable -from graphon.enums import NodeType -from graphon.file.models import File -from graphon.nodes import BuiltinNodeTypes -from 
graphon.nodes.variable_assigner.common.helpers import get_updated_variables -from graphon.variable_loader import VariableLoader -from graphon.variables import Segment, StringSegment, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from graphon.variables.types import SegmentType -from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 5fca444723..601e9261fc 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,6 +9,10 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker @@ -22,10 +26,6 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 785f6f108c..b555676704 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,33 +5,6 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast -from sqlalchemy import exists, select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager -from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.app.file_access import DatabaseFileAccessController -from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager -from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.trigger.constants import is_trigger_node_type -from core.workflow.human_input_compat import ( - DeliveryChannelConfig, - normalize_human_input_node_data_for_graph, - parse_human_input_delivery_methods, -) -from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type -from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables -from core.workflow.variable_pool_initializer 
import add_node_inputs_to_pool, add_variables_to_pool -from core.workflow.workflow_entry import WorkflowEntry -from enums.cloud_plan import CloudPlan -from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated -from extensions.ext_database import db -from extensions.ext_storage import storage -from factories.file_factory import build_from_mapping, build_from_mappings from graphon.entities import GraphInitParams, WorkflowNodeExecution from graphon.entities.graph_config import NodeConfigDict from graphon.entities.pause_reason import HumanInputRequired @@ -57,6 +30,34 @@ from graphon.variable_loader import load_into_variable_pool from graphon.variables import VariableBase from graphon.variables.input_entities import VariableEntityType from graphon.variables.variables import Variable +from sqlalchemy import exists, select +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager +from core.repositories import DifyCoreRepositoryFactory +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.trigger.constants import is_trigger_node_type +from core.workflow.human_input_compat import ( + DeliveryChannelConfig, + normalize_human_input_node_data_for_graph, + parse_human_input_delivery_methods, +) +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace +from enums.cloud_plan import CloudPlan +from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated +from extensions.ext_database import db +from extensions.ext_storage import storage +from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account from models.human_input import HumanInputFormRecipient, RecipientType @@ -849,6 +850,13 @@ class WorkflowService: draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution, + outputs=outputs, + workflow_execution_id=None, + user_id=account.id, + ) + return workflow_node_execution def get_human_input_form_preview( diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 458099d99e..489467651d 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,6 +7,7 @@ from typing import Annotated, Any, TypeAlias, Union from celery import shared_task from flask import current_app, json +from graphon.runtime import 
GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -22,7 +23,6 @@ from core.app.entities.app_invoke_entities import ( from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 6365400dd1..0a73c91279 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -10,6 +10,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -22,7 +23,6 @@ from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db -from graphon.runtime import GraphRuntimeState from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index ed8a24b336..20335d9b9f 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -7,6 +7,7 @@ from pathlib import Path import click import pandas as pd from celery import shared_task +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func from core.db.session_factory import session_factory @@ -14,7 +15,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/tasks/enterprise_telemetry_task.py b/api/tasks/enterprise_telemetry_task.py new file mode 100644 index 0000000000..7d5ea7c0a5 --- /dev/null +++ b/api/tasks/enterprise_telemetry_task.py @@ -0,0 +1,52 @@ +"""Celery worker for enterprise metric/log telemetry events. + +This module defines the Celery task that processes telemetry envelopes +from the enterprise_telemetry queue. It deserializes envelopes and +dispatches them to the EnterpriseMetricHandler. +""" + +import json +import logging + +from celery import shared_task + +from enterprise.telemetry.contracts import TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + +logger = logging.getLogger(__name__) + + +@shared_task(queue="enterprise_telemetry") +def process_enterprise_telemetry(envelope_json: str) -> None: + """Process enterprise metric/log telemetry envelope. + + This task is enqueued by the TelemetryGateway for metric/log-only + events. It deserializes the envelope and dispatches to the handler. 
+ + Best-effort processing: logs errors but never raises, to avoid + failing user requests due to telemetry issues. + + Args: + envelope_json: JSON-serialized TelemetryEnvelope. + """ + try: + # Deserialize envelope + envelope_dict = json.loads(envelope_json) + envelope = TelemetryEnvelope.model_validate(envelope_dict) + + # Process through handler + handler = EnterpriseMetricHandler() + handler.handle(envelope) + + logger.debug( + "Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s", + envelope.tenant_id, + envelope.event_id, + envelope.case, + ) + except Exception: + # Best-effort: log and drop on error, never fail user request + logger.warning( + "Failed to process enterprise telemetry envelope, dropping event", + exc_info=True, + ) diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index fd743205a1..ca73b4d374 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,6 +2,8 @@ import logging from datetime import timedelta from celery import shared_task +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker @@ -9,8 +11,6 @@ from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository from extensions.ext_database import db from extensions.ext_storage import storage -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index f8ae3f4b6e..a316eec7b9 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,6 +6,7 @@ from typing import Any import click from celery import shared_task +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -14,7 +15,6 @@ from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail -from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 72e3b42ca7..c95b8db078 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -39,17 +39,36 @@ def process_trace_tasks(file_info): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: + trace_type = trace_info_info_map.get(trace_info_type) + if trace_type: + trace_info = trace_type(**trace_info) + + from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled + + if is_ee_telemetry_enabled(): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + try: + EnterpriseOtelTrace().trace(trace_info) + except Exception: + logger.exception("Enterprise trace failed for app_id: %s", app_id) + if trace_instance: with current_app.app_context(): - trace_type = 
trace_info_info_map.get(trace_info_type) - if trace_type: - trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) + logger.info("Processing trace tasks succeeded, app_id: %s", app_id) except Exception as e: - logger.info("error:\n\n\n%s\n\n\n\n", e) + logger.exception("Processing trace tasks failed, app_id: %s", app_id) failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" redis_client.incr(failed_key) - logger.info("Processing trace tasks failed, app_id: %s", app_id) finally: - storage.delete(file_path) + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 25ea53dfac..56626e372e 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,6 +12,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -28,7 +29,6 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType, unlimited -from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index ae1c2991c9..0c7f74c180 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -9,11 +9,11 @@ import json import logging from celery import shared_task +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory -from graphon.entities.workflow_execution import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index a0fd739325..f25ebe3bae 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -9,13 +9,13 @@ import json import logging from celery import shared_task -from sqlalchemy import select - -from core.db.session_factory import session_factory from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter +from sqlalchemy import select + +from core.db.session_factory import session_factory from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -125,7 +125,7 @@ def _create_node_execution_from_domain( else: node_execution.execution_metadata = "{}" - node_execution.status = execution.status.value + node_execution.status = execution.status node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time node_execution.created_by_role = creator_user_role @@ -159,7 +159,7 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode node_execution.execution_metadata = "{}" # Update other fields - node_execution.status = execution.status.value + 
node_execution.status = execution.status node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time node_execution.finished_at = execution.finished_at diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index a876b0c4aa..91245e879e 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,8 +1,9 @@ from collections.abc import Generator +from graphon.node_events import StreamCompletedEvent + from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index b2de11b068..3fdea10976 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index 878d9b24df..c1bb8e1245 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -4,13 +4,13 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader -from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index c4146d5ccd..ce04a158a8 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,9 +4,6 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - # import monkeypatch from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import ( @@ -26,6 +23,9 @@ from graphon.model_runtime.entities.model_entities import ( ) from graphon.model_runtime.entities.provider_entities 
import ConfigurateMethod, ProviderEntity +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + class MockModelClass(PluginModelClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 0b21ff1d2a..5c6636f31e 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,6 +3,10 @@ import unittest import uuid import pytest +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from sqlalchemy import delete from sqlalchemy.orm import Session @@ -11,10 +15,6 @@ from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType -from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index f6f4cf260b..38dc8bbb28 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import patch import pytest +from graphon.variables.segments import StringSegment from sqlalchemy import delete from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType -from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -193,6 +193,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -424,6 +425,7 @@ class TestDeleteDraftVariablesSessionCommit: def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index a9a2617bae..c0143faa85 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -1,11 +1,12 @@ from unittest.mock import MagicMock +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from 
core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory -from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 7573e00872..ce0c8bf8ca 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,17 @@ import time import uuid import pytest - -from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import NodeRunResult from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.limits import CodeNodeLimits from graphon.runtime import GraphRuntimeState, VariablePool + +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -172,7 +172,7 @@ def test_execute_code_output_validator(setup_code_executor_mock): result = node._run() assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Output result must be a string, got int instead" + assert result.error == "Output result must be a string, got int instead." 
def test_execute_code_output_validator_depth(): diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 17ea7de881..ce18486faf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,6 +3,11 @@ import uuid from urllib.parse import urlencode import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -11,11 +16,6 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -191,7 +191,6 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from core.workflow.system_variables import build_system_variables from graphon.enums import BuiltinNodeTypes from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -202,6 +201,8 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool + from core.workflow.system_variables import build_system_variables + # Create variable pool variable_pool = VariablePool( system_variables=build_system_variables(user_id="test", files=[]), diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index fa5d63cfbf..f0f3fcead1 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,11 +4,6 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.model_manager import ModelInstance -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent from graphon.nodes.llm.file_saver import LLMFileSaver @@ -17,6 +12,12 @@ from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.nodes.protocols import HttpClientProtocol from graphon.runtime import GraphRuntimeState, VariablePool + +from 
core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.llm_generator.output_parser.structured_output import _parse_structured_output +from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 367b5bbc11..3bf44df349 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,16 +3,17 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 9e3e1a47e3..2d728569be 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,14 +1,15 @@ import time import uuid -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index f9ec51ee10..750ced7075 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,17 +2,18 @@ import time import uuid from unittest.mock import MagicMock, patch -from 
core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.tools.utils.configuration import ToolParameterConfigurationManager
-from core.workflow.node_factory import DifyNodeFactory
-from core.workflow.node_runtime import DifyToolNodeRuntime
-from core.workflow.system_variables import build_system_variables
 from graphon.enums import WorkflowNodeExecutionStatus
 from graphon.graph import Graph
 from graphon.node_events import StreamCompletedEvent
 from graphon.nodes.protocols import ToolFileManagerProtocol
 from graphon.nodes.tool.tool_node import ToolNode
 from graphon.runtime import GraphRuntimeState, VariablePool
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.tools.utils.configuration import ToolParameterConfigurationManager
+from core.workflow.node_factory import DifyNodeFactory
+from core.workflow.node_runtime import DifyToolNodeRuntime
+from core.workflow.system_variables import build_system_variables
 from tests.workflow_test_utils import build_test_graph_init_params
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
index 5b51510388..5cc458fe2e 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
@@ -4,11 +4,11 @@ import json
 import uuid
 
 from flask.testing import FlaskClient
+from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy.orm import Session
 
 from configs import dify_config
 from constants import HEADER_NAME_CSRF_TOKEN
-from graphon.enums import WorkflowExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from libs.token import _real_cookie_name, generate_csrf_token
 from models import Account, DifySetup, Tenant, TenantAccountJoin
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
index 290be87697..8ddf867370 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
@@ -3,12 +3,12 @@ import uuid
 
 from flask.testing import FlaskClient
+from graphon.variables.segments import StringSegment
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
 from factories.variable_factory import segment_to_variable
-from graphon.variables.segments import StringSegment from models import Workflow from models.model import AppMode from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index b8840c4ba8..2b4c1b59ab 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,6 +22,13 @@ import uuid from time import time import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -33,16 +40,6 @@ from core.app.layers.pause_state_persist_layer import ( ) from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage -from graphon.entities.pause_reason import SchedulingPause -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events.graph import GraphRunPausedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from graphon.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from graphon.runtime.variable_pool import VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel @@ -545,7 +542,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer.initialize(graph_runtime_state, command_channel) # Import other event types - from graphon.graph_events.graph import ( + from graphon.graph_events import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index e0c58f0f5c..13caad799e 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,6 +4,7 @@ from __future__ import annotations from uuid import uuid4 +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from sqlalchemy import Engine, select from sqlalchemy.orm import Session @@ -17,7 +18,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from 
models.human_input import ( EmailExternalRecipientPayload, diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index ae8c0716a4..0a9b476afc 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -4,6 +4,18 @@ from datetime import timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import delete, select from sqlalchemy.orm import Session @@ -15,18 +27,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_runtime import DifyHumanInputNodeRuntime from core.workflow.system_variables import build_system_variables -from graphon.enums import WorkflowType -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 2e207ddc67..cc72dc1cf3 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -4,13 +4,13 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader -from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole diff --git 
a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index 0fd03813da..b745aed141 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -1,11 +1,13 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import timedelta from decimal import Decimal from uuid import uuid4 from graphon.nodes.human_input.entities import FormDefinition, UserAction + +from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import HumanInputContent @@ -117,7 +119,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: inputs=[], user_actions=[UserAction(id=action_id, title=action_text)], rendered_content="Rendered block", - expiration_time=datetime.utcnow() + timedelta(days=1), + expiration_time=naive_utc_now() + timedelta(days=1), node_title=node_title, display_in_ui=True, ) @@ -129,7 +131,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: form_definition=form_definition.model_dump_json(), rendered_content="Rendered block", status=HumanInputFormStatus.SUBMITTED, - expiration_time=datetime.utcnow() + timedelta(days=1), + expiration_time=naive_utc_now() + timedelta(days=1), selected_action_id=action_id, ) db_session.add(form) diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 641399c7f9..a68b3a08c7 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index cb00752b35..d28cfda159 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -8,15 +8,15 @@ from unittest.mock import Mock from uuid import uuid4 import pytest -from sqlalchemy import Engine, delete, select -from sqlalchemy.orm import Session, sessionmaker - -from extensions.ext_storage import storage from graphon.entities import WorkflowExecution from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import FormDefinition, 
FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index 3d4ec25150..7f44eb6ca3 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -7,16 +7,17 @@ from __future__ import annotations from collections.abc import Generator from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import timedelta from decimal import Decimal from uuid import uuid4 import pytest +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from graphon.nodes.human_input.entities import FormDefinition, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus +from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import ExecutionExtraContent, HumanInputContent @@ -174,7 +175,7 @@ def _create_submitted_form( action_title: str = "Approve", node_title: str = "Approval", ) -> HumanInputForm: - expiration_time = datetime.utcnow() + timedelta(days=1) + expiration_time = naive_utc_now() + timedelta(days=1) form_definition = FormDefinition( form_content="content", inputs=[], @@ -207,7 +208,7 @@ def _create_waiting_form( workflow_run_id: str, default_values: dict | None = None, ) -> HumanInputForm: - expiration_time = datetime.utcnow() + timedelta(days=1) + expiration_time = naive_utc_now() + timedelta(days=1) form_definition = FormDefinition( form_content="content", inputs=[], diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py index d6f0657380..c5e9201ee3 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -7,12 +7,12 @@ from datetime import timedelta from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy import exc as sa_exc from sqlalchemy.orm import Session, sessionmaker -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/graphon/model_runtime/__init__.py b/api/tests/test_containers_integration_tests/services/auth/__init__.py 
similarity index 100% rename from api/graphon/model_runtime/__init__.py rename to api/tests/test_containers_integration_tests/services/auth/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py new file mode 100644 index 0000000000..177fb95ff3 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_service import ApiKeyAuthService + + +class TestApiKeyAuthService: + @pytest.fixture + def tenant_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def category(self) -> str: + return "search" + + @pytest.fixture + def provider(self) -> str: + return "google" + + @pytest.fixture + def mock_credentials(self) -> dict: + return {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}} + + @pytest.fixture + def mock_args(self, category, provider, mock_credentials) -> dict: + return {"category": category, "provider": provider, "credentials": mock_credentials} + + def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False): + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=json.dumps(credentials, ensure_ascii=False) if credentials else None, + disabled=disabled, + ) + db_session.add(binding) + db_session.commit() + return binding + + def test_get_provider_auth_list_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + assert len(result) >= 1 + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert len(tenant_results) == 1 + assert tenant_results[0].provider == provider + + def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id): + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert tenant_results == [] + + def test_get_provider_auth_list_filters_disabled( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + self._create_binding( + db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert tenant_results == [] + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_success( + self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" + + 
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + mock_factory.assert_called_once() + mock_auth_instance.validate_credentials.assert_called_once() + mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, "test_secret_key_123") + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() + assert len(bindings) == 1 + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def test_create_provider_auth_validation_failed( + self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = False + mock_factory.return_value = mock_auth_instance + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() + assert len(bindings) == 0 + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encrypts_api_key( + self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" + + original_key = mock_args["credentials"]["config"]["api_key"] + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123" + assert mock_args["credentials"]["config"]["api_key"] != original_key + mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key) + + def test_get_auth_credentials_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials + ): + self._create_binding( + db_session_with_containers, + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=mock_credentials, + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result == mock_credentials + + def test_get_auth_credentials_not_found( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result is None + + def test_get_auth_credentials_json_parsing( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} + self._create_binding( + db_session_with_containers, + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=special_credentials, + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result == special_credentials + assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" + + def test_delete_provider_auth_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + binding = self._create_binding( + db_session_with_containers, tenant_id=tenant_id, category=category, 
provider=provider
+        )
+        binding_id = binding.id
+        db_session_with_containers.expire_all()
+
+        ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id)
+
+        db_session_with_containers.expire_all()
+        remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first()
+        assert remaining is None
+
+    def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id):
+        # Should not raise when binding not found
+        ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4()))
+
+    def test_validate_api_key_auth_args_success(self, mock_args):
+        ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_missing_category(self, mock_args):
+        del mock_args["category"]
+        with pytest.raises(ValueError, match="category is required"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_empty_category(self, mock_args):
+        mock_args["category"] = ""
+        with pytest.raises(ValueError, match="category is required"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_missing_provider(self, mock_args):
+        del mock_args["provider"]
+        with pytest.raises(ValueError, match="provider is required"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_empty_provider(self, mock_args):
+        mock_args["provider"] = ""
+        with pytest.raises(ValueError, match="provider is required"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_missing_credentials(self, mock_args):
+        del mock_args["credentials"]
+        with pytest.raises(ValueError, match="credentials is required"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_empty_credentials(self, mock_args):
+        mock_args["credentials"] = None
+        with pytest.raises(ValueError, match="credentials is required"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_invalid_credentials_type(self, mock_args):
+        mock_args["credentials"] = "not_a_dict"
+        with pytest.raises(ValueError, match="credentials must be a dictionary"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_missing_auth_type(self, mock_args):
+        del mock_args["credentials"]["auth_type"]
+        with pytest.raises(ValueError, match="auth_type is required"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    def test_validate_api_key_auth_args_empty_auth_type(self, mock_args):
+        mock_args["credentials"]["auth_type"] = ""
+        with pytest.raises(ValueError, match="auth_type is required"):
+            ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    @pytest.mark.parametrize(
+        "malicious_input",
+        [
+            "'; DROP TABLE users; --",
+            "../../../etc/passwd",
+            "\\x00\\x00",
+            "A" * 10000,
+        ],
+    )
+    def test_validate_api_key_auth_args_malicious_input(self, malicious_input, mock_args):
+        # The empty string is deliberately not parametrized here: it fails the
+        # required-field check and is already covered by
+        # test_validate_api_key_auth_args_empty_category above.
+        mock_args["category"] = malicious_input
+        ApiKeyAuthService.validate_api_key_auth_args(mock_args)
+
+    @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+    @patch("services.auth.api_key_auth_service.encrypter")
+    def test_create_provider_auth_database_error_handling(
+        self, mock_encrypter, mock_factory, flask_app_with_containers, tenant_id, mock_args
+    ):
+        mock_auth_instance = Mock()
+        mock_auth_instance.validate_credentials.return_value = True
+        mock_factory.return_value = mock_auth_instance
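+        # Factory validation and token encryption are stubbed to succeed, so the only
+        # failure injected in this test is the simulated DB commit error below.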
+ mock_encrypter.encrypt_token.return_value = "encrypted_key" + + with patch("services.auth.api_key_auth_service.db.session") as mock_session: + mock_session.commit.side_effect = Exception("Database error") + with pytest.raises(Exception, match="Database error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def test_create_provider_auth_factory_exception(self, mock_factory, tenant_id, mock_args): + mock_factory.side_effect = Exception("Factory error") + with pytest.raises(Exception, match="Factory error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, tenant_id, mock_args): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") + with pytest.raises(Exception, match="Encryption error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + def test_validate_api_key_auth_args_none_input(self): + with pytest.raises(TypeError): + ApiKeyAuthService.validate_api_key_auth_args(None) + + def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self, mock_args): + mock_args["credentials"]["auth_type"] = ["api_key"] + ApiKeyAuthService.validate_api_key_auth_args(mock_args) diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py new file mode 100644 index 0000000000..dc4c0fda1d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -0,0 +1,264 @@ +""" +API Key Authentication System Integration Tests +""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import Mock, patch +from uuid import uuid4 + +import httpx +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_factory import ApiKeyAuthFactory +from services.auth.api_key_auth_service import ApiKeyAuthService +from services.auth.auth_type import AuthType + + +class TestAuthIntegration: + @pytest.fixture + def tenant_id_1(self) -> str: + return str(uuid4()) + + @pytest.fixture + def tenant_id_2(self) -> str: + return str(uuid4()) + + @pytest.fixture + def category(self) -> str: + return "search" + + @pytest.fixture + def firecrawl_credentials(self) -> dict: + return {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}} + + @pytest.fixture + def jina_credentials(self) -> dict: + return {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}} + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + def test_end_to_end_auth_flow( + self, + mock_encrypt, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_fc_test_key_123" + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + 
mock_http.assert_called_once() + call_args = mock_http.call_args + assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0] + assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123" + + mock_encrypt.assert_called_once_with(tenant_id_1, "fc_test_key_123") + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() + assert len(bindings) == 1 + assert bindings[0].provider == AuthType.FIRECRAWL + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_cross_component_integration(self, mock_http, firecrawl_credentials): + mock_http.return_value = self._create_success_response() + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials) + result = factory.validate_credentials() + + assert result is True + mock_http.assert_called_once() + + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.jina.jina.httpx.post") + def test_multi_tenant_isolation( + self, + mock_jina_http, + mock_fc_http, + mock_encrypt, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + tenant_id_2, + category, + firecrawl_credentials, + jina_credentials, + ): + mock_fc_http.return_value = self._create_success_response() + mock_jina_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_key" + + args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args1) + + args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_2, args2) + + db_session_with_containers.expire_all() + + result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1) + result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2) + + assert len(result1) == 1 + assert result1[0].tenant_id == tenant_id_1 + assert len(result2) == 1 + assert result2[0].tenant_id == tenant_id_2 + + def test_cross_tenant_access_prevention( + self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category + ): + result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) + + assert result is None + + def test_sensitive_data_protection(self): + credentials_with_secrets = { + "auth_type": "bearer", + "config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"}, + } + + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets) + factory_str = str(factory) + + assert "super_secret_key_do_not_log" not in factory_str + assert "another_secret" not in factory_str + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token", return_value="encrypted_key") + def test_concurrent_creation_safety( + self, + mock_encrypt, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + app = flask_app_with_containers + mock_http.return_value = self._create_success_response() + + results = [] + exceptions = [] + + def create_auth(): + try: + with app.app_context(): + thread_args = { + "category": category, + "provider": AuthType.FIRECRAWL, + "credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}, + } + ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args) + 
results.append("success") + except Exception as e: + exceptions.append(e) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_auth) for _ in range(5)] + for future in futures: + future.result() + + assert len(results) == 5 + assert len(exceptions) == 0 + + @pytest.mark.parametrize( + "invalid_input", + [ + None, + {}, + {"auth_type": "bearer"}, + {"auth_type": "bearer", "config": {}}, + ], + ) + def test_invalid_input_boundary(self, invalid_input): + with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): + ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_http_error_handling(self, mock_http, firecrawl_credentials): + mock_response = Mock() + mock_response.status_code = 401 + mock_response.text = '{"error": "Unauthorized"}' + mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized") + mock_http.return_value = mock_response + + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials) + with pytest.raises((httpx.HTTPError, Exception)): + factory.validate_credentials() + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_network_failure_recovery( + self, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.side_effect = httpx.RequestError("Network timeout") + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + + with pytest.raises(httpx.RequestError): + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() + assert len(bindings) == 0 + + @pytest.mark.parametrize( + ("provider", "credentials"), + [ + (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}), + (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}), + (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}), + ], + ) + def test_all_providers_factory_creation(self, provider, credentials): + auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) + assert auth_class is not None + + factory = ApiKeyAuthFactory(provider, credentials) + assert factory.auth is not None + + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_get_auth_credentials_returns_stored_credentials( + self, + mock_http, + mock_encrypt, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_key" + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL) + assert result is not None + assert result["config"]["api_key"] == "encrypted_key" + + def _create_success_response(self, status_code=200): + mock_response = Mock() + mock_response.status_code = status_code + mock_response.json.return_value = {"status": "success"} + mock_response.raise_for_status.return_value = None + return mock_response diff --git 
a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py index 3885137221..ce9f10e207 100644 --- a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest +from core.plugin.entities.plugin_daemon import CredentialType from models.tools import BuiltinToolProvider from services.plugin.plugin_parameter_service import PluginParameterService @@ -66,7 +67,7 @@ class TestGetDynamicSelectOptionsTool: provider="google", name="API KEY 1", encrypted_credentials=json.dumps({"api_key": "encrypted"}), - credential_type="api_key", + credential_type=CredentialType.API_KEY, ) db_session_with_containers.add(db_record) db_session_with_containers.commit() diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py similarity index 78% rename from api/tests/unit_tests/services/plugin/test_plugin_service.py rename to api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py index 09b9ab498b..0cdae572fb 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py @@ -8,15 +8,27 @@ verification, marketplace upgrade flows, and uninstall with credential cleanup. from __future__ import annotations from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import select from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin_daemon import PluginVerification +from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import PluginInstallationScope from services.plugin.plugin_service import PluginService -from tests.unit_tests.services.plugin.conftest import make_features + + +def _make_features( + restrict_to_marketplace: bool = False, + scope: PluginInstallationScope = PluginInstallationScope.ALL, +) -> MagicMock: + features = MagicMock() + features.plugin_installation_permission.restrict_to_marketplace_only = restrict_to_marketplace + features.plugin_installation_permission.plugin_installation_scope = scope + return features class TestFetchLatestPluginVersion: @@ -80,14 +92,14 @@ class TestFetchLatestPluginVersion: class TestCheckMarketplaceOnlyPermission: @patch("services.plugin.plugin_service.FeatureService") def test_raises_when_restricted(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=True) + mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_marketplace_only_permission() @patch("services.plugin.plugin_service.FeatureService") def test_passes_when_not_restricted(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=False) + mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False) PluginService._check_marketplace_only_permission() # should not raise @@ -95,7 +107,7 @@ class TestCheckMarketplaceOnlyPermission: class 
TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_only_allows_langgenius(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) verification = MagicMock() verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius @@ -103,14 +115,14 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_only_rejects_third_party(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(None) @patch("services.plugin.plugin_service.FeatureService") def test_official_and_partners_allows_partner(self, mock_fs): - mock_fs.get_system_features.return_value = make_features( + mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS ) verification = MagicMock() @@ -120,7 +132,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_and_partners_rejects_none(self, mock_fs): - mock_fs.get_system_features.return_value = make_features( + mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS ) @@ -129,7 +141,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_none_scope_always_raises(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.NONE) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE) verification = MagicMock() verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius @@ -138,7 +150,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_all_scope_passes_any(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.ALL) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL) PluginService._check_plugin_installation_scope(None) # should not raise @@ -209,9 +221,9 @@ class TestUpgradePluginWithMarketplace: @patch("services.plugin.plugin_service.dify_config") def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value - installer.fetch_plugin_manifest.return_value = MagicMock() # no exception = already installed + installer.fetch_plugin_manifest.return_value = MagicMock() installer.upgrade_plugin.return_value = MagicMock() PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") @@ -225,7 +237,7 @@ class TestUpgradePluginWithMarketplace: @patch("services.plugin.plugin_service.dify_config") def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = 
make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_download.return_value = b"pkg-bytes" @@ -244,7 +256,7 @@ class TestUpgradePluginWithGithub: @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.upgrade_plugin.return_value = MagicMock() @@ -259,7 +271,7 @@ class TestUploadPkg: @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() upload_resp = MagicMock() upload_resp.verification = None mock_installer_cls.return_value.upload_pkg.return_value = upload_resp @@ -283,7 +295,7 @@ class TestInstallFromMarketplacePkg: @patch("services.plugin.plugin_service.dify_config") def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_download.return_value = b"pkg" @@ -298,14 +310,14 @@ class TestInstallFromMarketplacePkg: assert result == "task-id" installer.install_from_identifiers.assert_called_once() call_args = installer.install_from_identifiers.call_args[0] - assert call_args[1] == ["resolved-uid"] # uses response uid, not input + assert call_args[1] == ["resolved-uid"] @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.dify_config") def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.return_value = MagicMock() decode_resp = MagicMock() @@ -317,7 +329,7 @@ class TestInstallFromMarketplacePkg: installer.install_from_identifiers.assert_called_once() call_args = installer.install_from_identifiers.call_args[0] - assert call_args[1] == ["uid-1"] # uses original uid + assert call_args[1] == ["uid-1"] class TestUninstall: @@ -332,26 +344,70 @@ class TestUninstall: assert result is True installer.uninstall.assert_called_once_with("t1", "install-1") - @patch("services.plugin.plugin_service.db") @patch("services.plugin.plugin_service.PluginInstaller") - def test_cleans_credentials_when_plugin_found(self, mock_installer_cls, mock_db): + def test_cleans_credentials_when_plugin_found( + self, mock_installer_cls, flask_app_with_containers, db_session_with_containers + ): + tenant_id = str(uuid4()) + plugin_id = "org/myplugin" + provider_name = f"{plugin_id}/model-provider" + + credential = ProviderCredential( + tenant_id=tenant_id, + provider_name=provider_name, + credential_name="default", + 
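+            # "{}" stands in for real encrypted credentials; the cleanup assertions
+            # below only require that the credential row exists before uninstall.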
encrypted_config="{}", + ) + db_session_with_containers.add(credential) + db_session_with_containers.flush() + credential_id = credential.id + + provider = Provider( + tenant_id=tenant_id, + provider_name=provider_name, + credential_id=credential_id, + ) + db_session_with_containers.add(provider) + db_session_with_containers.flush() + provider_id = provider.id + + pref = TenantPreferredModelProvider( + tenant_id=tenant_id, + provider_name=provider_name, + preferred_provider_type="custom", + ) + db_session_with_containers.add(pref) + db_session_with_containers.commit() + plugin = MagicMock() plugin.installation_id = "install-1" - plugin.plugin_id = "org/myplugin" + plugin.plugin_id = plugin_id installer = mock_installer_cls.return_value installer.list_plugins.return_value = [plugin] installer.uninstall.return_value = True - # Mock Session context manager - mock_session = MagicMock() - mock_db.engine = MagicMock() - mock_session.scalars.return_value.all.return_value = [] # no credentials found - - with patch("services.plugin.plugin_service.Session") as mock_session_cls: - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - - result = PluginService.uninstall("t1", "install-1") + with patch("services.plugin.plugin_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + result = PluginService.uninstall(tenant_id, "install-1") assert result is True installer.uninstall.assert_called_once() + + db_session_with_containers.expire_all() + + remaining_creds = db_session_with_containers.scalars( + select(ProviderCredential).where(ProviderCredential.id == credential_id) + ).all() + assert len(remaining_creds) == 0 + + updated_provider = db_session_with_containers.get(Provider, provider_id) + assert updated_provider is not None + assert updated_provider.credential_id is None + + remaining_prefs = db_session_with_containers.scalars( + select(TenantPreferredModelProvider).where( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name == provider_name, + ) + ).all() + assert len(remaining_prefs) == 0 diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 00a2f9a59f..4f3c0e4200 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -842,6 +842,7 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from graphon.file import FileTransferMethod, FileType + from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/unit_tests/services/test_api_token_service.py b/api/tests/test_containers_integration_tests/services/test_api_token_service.py similarity index 71% rename from api/tests/unit_tests/services/test_api_token_service.py rename to api/tests/test_containers_integration_tests/services/test_api_token_service.py index ad4de93b25..a2028d3ed3 100644 --- a/api/tests/unit_tests/services/test_api_token_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_token_service.py @@ -1,80 +1,63 @@ +from __future__ import annotations + from datetime import datetime -from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest from 
werkzeug.exceptions import Unauthorized import services.api_token_service as api_token_service_module +from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken -@pytest.fixture -def mock_db_session(): - """Fixture providing common DB session mocking for query_token_from_db tests.""" - fake_engine = MagicMock() - - session = MagicMock() - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - with ( - patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)), - patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class, - patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, - patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, - ): - yield { - "session": session, - "mock_session_class": mock_session_class, - "mock_cache_set": mock_cache_set, - "mock_record_usage": mock_record_usage, - "fake_engine": fake_engine, - } - - class TestQueryTokenFromDb: - def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session): - """Test DB lookup success path caches token and records usage.""" - # Arrange - auth_token = "token-123" - scope = "app" - api_token = MagicMock() + def test_should_return_api_token_and_cache_when_token_exists( + self, flask_app_with_containers, db_session_with_containers + ): + tenant_id = str(uuid4()) + app_id = str(uuid4()) + token_value = f"app-test-{uuid4()}" - mock_db_session["session"].scalar.return_value = api_token + api_token = ApiToken() + api_token.id = str(uuid4()) + api_token.app_id = app_id + api_token.tenant_id = tenant_id + api_token.type = "app" + api_token.token = token_value + db_session_with_containers.add(api_token) + db_session_with_containers.commit() - # Act - result = api_token_service_module.query_token_from_db(auth_token, scope) + with ( + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + result = api_token_service_module.query_token_from_db(token_value, "app") - # Assert - assert result == api_token - mock_db_session["mock_session_class"].assert_called_once_with( - mock_db_session["fake_engine"], expire_on_commit=False - ) - mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token) - mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope) + assert result.id == api_token.id + assert result.token == token_value + mock_cache_set.assert_called_once() + mock_record_usage.assert_called_once_with(token_value, "app") - def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session): - """Test DB lookup miss path caches null marker and raises Unauthorized.""" - # Arrange - auth_token = "missing-token" - scope = "app" + def test_should_cache_null_and_raise_unauthorized_when_token_not_found( + self, flask_app_with_containers, db_session_with_containers + ): + with ( + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.query_token_from_db(f"missing-{uuid4()}", "app") - mock_db_session["session"].scalar.return_value = None - - # Act / Assert - with pytest.raises(Unauthorized, 
match="Access token is invalid"): - api_token_service_module.query_token_from_db(auth_token, scope) - - mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None) - mock_db_session["mock_record_usage"].assert_not_called() + mock_cache_set.assert_called_once() + call_args = mock_cache_set.call_args[0] + assert call_args[2] is None # cached None + mock_record_usage.assert_not_called() class TestRecordTokenUsage: def test_should_write_active_key_with_iso_timestamp_and_ttl(self): - """Test record_token_usage writes usage timestamp with one-hour TTL.""" - # Arrange auth_token = "token-123" scope = "dataset" fixed_time = datetime(2026, 2, 24, 12, 0, 0) @@ -84,26 +67,18 @@ class TestRecordTokenUsage: patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time), patch.object(api_token_service_module, "redis_client") as mock_redis, ): - # Act api_token_service_module.record_token_usage(auth_token, scope) - # Assert mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600) def test_should_not_raise_when_redis_write_fails(self): - """Test record_token_usage swallows Redis errors.""" - # Arrange with patch.object(api_token_service_module, "redis_client") as mock_redis: mock_redis.set.side_effect = Exception("redis unavailable") - - # Act / Assert api_token_service_module.record_token_usage("token-123", "app") class TestFetchTokenWithSingleFlight: def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self): - """Test single-flight returns cache when another request already populated it.""" - # Arrange auth_token = "token-123" scope = "app" cached_token = CachedApiToken( @@ -115,39 +90,26 @@ class TestFetchTokenWithSingleFlight: last_used_at=None, created_at=None, ) - lock = MagicMock() lock.acquire.return_value = True with ( patch.object(api_token_service_module, "redis_client") as mock_redis, - patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token), patch.object(api_token_service_module, "query_token_from_db") as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == cached_token - mock_redis.lock.assert_called_once_with( - f"api_token_query_lock:{scope}:{auth_token}", - timeout=10, - blocking_timeout=5, - ) lock.acquire.assert_called_once_with(blocking=True) lock.release.assert_called_once() - mock_cache_get.assert_called_once_with(auth_token, scope) mock_query_db.assert_not_called() def test_should_query_db_when_lock_acquired_and_cache_missed(self): - """Test single-flight queries DB when cache remains empty after lock acquisition.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.return_value = True @@ -157,22 +119,16 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_query_db.assert_called_once_with(auth_token, scope) lock.release.assert_called_once() def test_should_query_db_directly_when_lock_not_acquired(self): - """Test lock timeout branch falls back to direct DB query.""" - # Arrange auth_token = "token-123" scope = "app" 
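+        # Lock acquisition fails below (acquire returns False), so the caller must
+        # skip the cache and query the database directly.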
db_token = MagicMock() - lock = MagicMock() lock.acquire.return_value = False @@ -182,19 +138,14 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_cache_get.assert_not_called() mock_query_db.assert_called_once_with(auth_token, scope) lock.release.assert_not_called() def test_should_reraise_unauthorized_from_db_query(self): - """Test Unauthorized from DB query is propagated unchanged.""" - # Arrange auth_token = "token-123" scope = "app" lock = MagicMock() @@ -210,20 +161,15 @@ class TestFetchTokenWithSingleFlight: ), ): mock_redis.lock.return_value = lock - - # Act / Assert with pytest.raises(Unauthorized, match="Access token is invalid"): api_token_service_module.fetch_token_with_single_flight(auth_token, scope) lock.release.assert_called_once() def test_should_fallback_to_db_query_when_lock_raises_exception(self): - """Test Redis lock errors fall back to direct DB query.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.side_effect = RuntimeError("redis lock error") @@ -232,11 +178,8 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_query_db.assert_called_once_with(auth_token, scope) @@ -244,8 +187,6 @@ class TestFetchTokenWithSingleFlight: class TestApiTokenCacheTenantBranches: @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis): - """Test scoped delete removes cache key and tenant index membership.""" - # Arrange token = "token-123" scope = "app" cache_key = ApiTokenCache._make_cache_key(token, scope) @@ -261,18 +202,14 @@ class TestApiTokenCacheTenantBranches: mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8") with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index: - # Act result = ApiTokenCache.delete(token, scope) - # Assert assert result is True mock_redis.delete.assert_called_once_with(cache_key) mock_remove_index.assert_called_once_with("tenant-1", cache_key) @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis): - """Test tenant invalidation deletes indexed cache entries and index key.""" - # Arrange tenant_id = "tenant-1" index_key = ApiTokenCache._make_tenant_index_key(tenant_id) mock_redis.smembers.return_value = { @@ -280,10 +217,8 @@ class TestApiTokenCacheTenantBranches: b"api_token:any:token-2", } - # Act result = ApiTokenCache.invalidate_by_tenant(tenant_id) - # Assert assert result is True mock_redis.smembers.assert_called_once_with(index_key) mock_redis.delete.assert_any_call("api_token:app:token-1") @@ -293,7 +228,6 @@ class TestApiTokenCacheTenantBranches: class TestApiTokenCacheCoreBranches: def test_cached_api_token_repr_should_include_id_and_type(self): - """Test CachedApiToken __repr__ includes key identity fields.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -303,11 +237,9 @@ class TestApiTokenCacheCoreBranches: last_used_at=None, 
created_at=None, ) - assert repr(token) == "" def test_serialize_token_should_handle_cached_api_token_instances(self): - """Test serialization path when input is already a CachedApiToken.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -317,35 +249,25 @@ class TestApiTokenCacheCoreBranches: last_used_at=None, created_at=None, ) - serialized = ApiTokenCache._serialize_token(token) - assert isinstance(serialized, bytes) assert b'"id":"id-123"' in serialized - assert b'"token":"token-123"' in serialized def test_deserialize_token_should_return_none_for_null_markers(self): - """Test null cache marker deserializes to None.""" assert ApiTokenCache._deserialize_token("null") is None assert ApiTokenCache._deserialize_token(b"null") is None def test_deserialize_token_should_return_none_for_invalid_payload(self): - """Test invalid serialized payload returns None.""" assert ApiTokenCache._deserialize_token("not-json") is None @patch("services.api_token_service.redis_client") def test_get_should_return_none_on_cache_miss(self, mock_redis): - """Test cache miss branch in ApiTokenCache.get.""" mock_redis.get.return_value = None - result = ApiTokenCache.get("token-123", "app") - assert result is None - mock_redis.get.assert_called_once_with("api_token:app:token-123") @patch("services.api_token_service.redis_client") def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis): - """Test cache hit branch in ApiTokenCache.get.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -356,48 +278,34 @@ class TestApiTokenCacheCoreBranches: created_at=None, ) mock_redis.get.return_value = token.model_dump_json().encode("utf-8") - result = ApiTokenCache.get("token-123", "app") - assert isinstance(result, CachedApiToken) assert result.id == "id-123" @patch("services.api_token_service.redis_client") def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): - """Test tenant index update exits early for missing tenant id.""" ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123") - mock_redis.sadd.assert_not_called() - mock_redis.expire.assert_not_called() @patch("services.api_token_service.redis_client") def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis): - """Test tenant index update handles Redis write errors gracefully.""" mock_redis.sadd.side_effect = Exception("redis down") - ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123") - mock_redis.sadd.assert_called_once() @patch("services.api_token_service.redis_client") def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): - """Test tenant index removal exits early for missing tenant id.""" ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123") - mock_redis.srem.assert_not_called() @patch("services.api_token_service.redis_client") def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis): - """Test tenant index removal handles Redis errors gracefully.""" mock_redis.srem.side_effect = Exception("redis down") - ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123") - mock_redis.srem.assert_called_once() @patch("services.api_token_service.redis_client") def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis): - """Test set returns False when Redis setex fails.""" mock_redis.setex.side_effect = Exception("redis write failed") api_token = MagicMock() api_token.id = "id-123" @@ -407,60 +315,41 @@ class 
TestApiTokenCacheCoreBranches: api_token.token = "token-123" api_token.last_used_at = None api_token.created_at = None - result = ApiTokenCache.set("token-123", "app", api_token) - assert result is False @patch("services.api_token_service.redis_client") def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis): - """Test delete(scope=None) returns False when scan_iter raises.""" mock_redis.scan_iter.side_effect = Exception("scan failed") - result = ApiTokenCache.delete("token-123", None) - assert result is False @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis): - """Test scoped delete still succeeds when tenant lookup from cache fails.""" token = "token-123" scope = "app" cache_key = ApiTokenCache._make_cache_key(token, scope) mock_redis.get.side_effect = Exception("get failed") - result = ApiTokenCache.delete(token, scope) - assert result is True mock_redis.delete.assert_called_once_with(cache_key) @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis): - """Test scoped delete returns False when delete operation fails.""" - token = "token-123" - scope = "app" mock_redis.get.return_value = None mock_redis.delete.side_effect = Exception("delete failed") - - result = ApiTokenCache.delete(token, scope) - + result = ApiTokenCache.delete("token-123", "app") assert result is False @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis): - """Test tenant invalidation returns True when tenant index is empty.""" mock_redis.smembers.return_value = set() - result = ApiTokenCache.invalidate_by_tenant("tenant-123") - assert result is True mock_redis.delete.assert_not_called() @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis): - """Test tenant invalidation returns False when Redis operation fails.""" mock_redis.smembers.side_effect = Exception("redis failed") - result = ApiTokenCache.invalidate_by_tenant("tenant-123") - assert result is False diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index 02ab3f8314..fb0adbbcc2 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,10 +3,10 @@ from uuid import uuid4 import pytest +from graphon.variables import StringVariable from sqlalchemy.orm import sessionmaker from extensions.ext_database import db -from graphon.variables import StringVariable from models.workflow import ConversationVariable from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0de3c64c4f..f9bfa570cb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities 
import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 883c3c3feb..a814466e14 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -2,10 +2,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType -from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings from models.enums import DataSourceType diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index fe426ae516..c8f04e9215 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. 
from datetime import UTC, datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import select -from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 18c5320d0a..c46b8fba0b 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,6 +3,8 @@ import uuid from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from core.workflow.human_input_compat import ( EmailDeliveryConfig, @@ -10,8 +12,6 @@ from core.workflow.human_input_compat import ( EmailRecipients, ExternalRecipient, ) -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py similarity index 85% rename from api/tests/unit_tests/services/test_human_input_delivery_test_service.py rename to api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 2c0f561860..0f252515f7 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -1,7 +1,11 @@ +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from graphon.runtime import VariablePool from sqlalchemy.engine import Engine from configs import dify_config @@ -12,7 +16,7 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.runtime import VariablePool +from models.account import Account, TenantAccountJoin from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( DeliveryTestContext, @@ -28,13 +32,6 @@ from services.human_input_delivery_test_service import ( ) -@pytest.fixture -def mock_db(monkeypatch): - mock_db = MagicMock() - monkeypatch.setattr(service_module, "db", mock_db) - return mock_db - - def _make_valid_email_config(): return EmailDeliveryConfig( recipients=EmailRecipients(include_bound_group=False, items=[]), @@ -91,7 +88,7 @@ class TestDeliveryTestRegistry: with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): registry.dispatch(context=context, method=method) - def test_default(self, mock_db): + def test_default(self, flask_app_with_containers, db_session_with_containers): registry = DeliveryTestRegistry.default() assert len(registry._handlers) == 1 assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) @@ -250,10 +247,8 @@ 
class TestEmailDeliveryTestHandler: _, kwargs = mock_mail_send.call_args assert kwargs["subject"] == "Notice BCC:test@example.com" - def test_resolve_recipients(self): + def test_resolve_recipients_external(self): handler = EmailDeliveryTestHandler(session_factory=MagicMock()) - - # Test Case 1: External Recipient method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( @@ -265,18 +260,43 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - # Test Case 2: Member Recipient + def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + account = Account(name="Test User", email="member@example.com") + db_session_with_containers.add(account) + db_session_with_containers.commit() + + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account.id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + from extensions.ext_database import db + + handler = EmailDeliveryTestHandler(session_factory=db.engine) method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[MemberRecipient(reference_id="u1")], include_bound_group=False), + recipients=EmailRecipients(items=[MemberRecipient(reference_id=account.id)], include_bound_group=False), subject="", body="", ) ) - handler._query_workspace_member_emails = MagicMock(return_value={"u1": "u1@example.com"}) - assert handler._resolve_recipients(tenant_id="t1", method=method) == ["u1@example.com"] + assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"] - # Test Case 3: Whole Workspace + def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com") + account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com") + db_session_with_containers.add_all([account1, account2]) + db_session_with_containers.commit() + + for acc in [account1, account2]: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=acc.id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + from extensions.ext_database import db + + handler = EmailDeliveryTestHandler(session_factory=db.engine) method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients(items=[], include_bound_group=True), @@ -284,34 +304,13 @@ class TestEmailDeliveryTestHandler: body="", ) ) - handler._query_workspace_member_emails = MagicMock( - return_value={"u1": "u1@example.com", "u2": "u2@example.com"} - ) - recipients = handler._resolve_recipients(tenant_id="t1", method=method) - assert set(recipients) == {"u1@example.com", "u2@example.com"} + recipients = handler._resolve_recipients(tenant_id=tenant_id, method=method) + assert set(recipients) == {account1.email, account2.email} - def test_query_workspace_member_emails(self): - mock_session = MagicMock() - mock_session_factory = MagicMock(return_value=mock_session) - mock_session.__enter__.return_value = mock_session - - handler = EmailDeliveryTestHandler(session_factory=mock_session_factory) - - # Empty user_ids + def test_query_workspace_member_emails_empty_ids(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) assert handler._query_workspace_member_emails(tenant_id="t1", user_ids=[]) == {} - # user_ids is None (all) - mock_execute = MagicMock() - 
mock_session.execute.return_value = mock_execute - mock_execute.all.return_value = [("u1", "u1@example.com")] - - result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None) - assert result == {"u1": "u1@example.com"} - - # user_ids with values - result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=["u1"]) - assert result == {"u1": "u1@example.com"} - def test_build_substitutions(self): context = DeliveryTestContext( tenant_id="t1", @@ -333,7 +332,6 @@ class TestEmailDeliveryTestHandler: assert subs["form_token"] == "token123" assert "form/token123" in subs["form_link"] - # Without matching recipient subs_no_match = EmailDeliveryTestHandler._build_substitutions( context=context, recipient_email="other@example.com" ) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index c0c1c25f1e..2340dd2a03 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -1,16 +1,18 @@ +from __future__ import annotations + import datetime import json import uuid from decimal import Decimal -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.file import FileType from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client -from graphon.file.enums import FileType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ( ConversationFromSource, @@ -1169,3 +1171,66 @@ class TestMessagesCleanServiceIntegration: # Verify all messages were deleted assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0 + + def test_from_time_range_validation(self): + """Test that from_time_range raises ValueError for invalid inputs.""" + policy = MagicMock(spec=BillingDisabledPolicy) + now = datetime.datetime.now() + + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range(policy, now, now) + + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range(policy, now - datetime.timedelta(days=1), now, batch_size=0) + + def test_from_time_range_success(self): + """Test that from_time_range creates a service with correct parameters.""" + policy = MagicMock(spec=BillingDisabledPolicy) + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 2, 1) + + service = MessagesCleanService.from_time_range(policy, start, end) + assert service._start_from == start + assert service._end_before == end + + def test_from_days_validation(self): + """Test that from_days raises ValueError for invalid inputs.""" + policy = MagicMock(spec=BillingDisabledPolicy) + + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + MessagesCleanService.from_days(policy, days=-1) + + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy, days=30, batch_size=0) + + def test_from_days_success(self): + """Test that from_days creates a service with correct parameters.""" + policy = MagicMock(spec=BillingDisabledPolicy) + + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: + fixed_now 
= datetime.datetime(2024, 6, 1) + mock_now.return_value = fixed_now + + service = MessagesCleanService.from_days(policy, days=10) + assert service._start_from is None + assert service._end_before == fixed_now - datetime.timedelta(days=10) + + def test_batch_delete_message_relations_empty(self, db_session_with_containers: Session): + """Test that batch_delete_message_relations with empty list does nothing.""" + # Call with an empty list of message IDs + MessagesCleanService._batch_delete_message_relations(db_session_with_containers, []) + # No exception means success; an empty list is a no-op + + def test_run_calls_clean_messages(self): + """Test that run() delegates to _clean_messages_by_time_range.""" + policy = MagicMock(spec=BillingDisabledPolicy) + service = MessagesCleanService( + policy=policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + with patch.object(service, "_clean_messages_by_time_range") as mock_clean: + mock_clean.return_value = {"total_deleted": 5} + result = service.run() + assert result == {"total_deleted": 5} + mock_clean.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 8955a3b5f2..ba926bf675 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -405,10 +405,11 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models - from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.provider_entities import ProviderEntity + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity + # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( ProviderEntity( @@ -643,9 +644,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response - from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity + mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", model_type=ModelType.LLM, diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 2a18345c87..749c6fff5b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -8,9 +8,9 @@ from unittest.mock import patch import pytest from faker import Faker +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session -from graphon.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 86cf2327c7..0c281c8c33 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker +from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index ee7b68e6aa..b5ce8a53de 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -555,6 +555,124 @@ class TestWorkflowService: assert len(result_workflows) == 2 assert all(wf.marked_name for wf in result_workflows) + def test_get_all_published_workflow_no_workflow_id(self, db_session_with_containers: Session): + """Test that an app with no workflow_id returns empty results.""" + # Arrange + fake = Faker() + app = self._create_test_app(db_session_with_containers, fake) + app.workflow_id = None + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None + ) + + # Assert + assert result_workflows == [] + assert has_more is False + + def test_get_all_published_workflow_basic(self, db_session_with_containers: Session): + """Test basic retrieval of published workflows.""" + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + workflow1 = self._create_test_workflow(db_session_with_containers, app, account, fake) + workflow1.version = "2024.01.01.001" + workflow2 = self._create_test_workflow(db_session_with_containers, app, account, fake) + workflow2.version = "2024.01.02.001" + + app.workflow_id = workflow1.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None + ) + + # Assert + assert len(result_workflows) == 2 + assert has_more is False + + def test_get_all_published_workflow_combined_filters(self, db_session_with_containers: 
Session): + """Test combined user_id and named_only filters.""" + # Arrange + fake = Faker() + account1 = self._create_test_account(db_session_with_containers, fake) + account2 = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + # account1 named + wf1 = self._create_test_workflow(db_session_with_containers, app, account1, fake) + wf1.version = "2024.01.01.001" + wf1.marked_name = "Named by user1" + wf1.created_by = account1.id + + # account1 unnamed + wf2 = self._create_test_workflow(db_session_with_containers, app, account1, fake) + wf2.version = "2024.01.02.001" + wf2.marked_name = "" + wf2.created_by = account1.id + + # account2 named + wf3 = self._create_test_workflow(db_session_with_containers, app, account2, fake) + wf3.version = "2024.01.03.001" + wf3.marked_name = "Named by user2" + wf3.created_by = account2.id + + app.workflow_id = wf1.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act - Filter by account1 + named_only + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, + app_model=app, + page=1, + limit=10, + user_id=account1.id, + named_only=True, + ) + + # Assert - Only wf1 matches (account1 + named) + assert len(result_workflows) == 1 + assert result_workflows[0].marked_name == "Named by user1" + assert result_workflows[0].created_by == account1.id + + def test_get_all_published_workflow_empty_result(self, db_session_with_containers: Session): + """Test that querying with no matching workflows returns empty.""" + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + # Create a draft workflow (no version set = draft) + workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) + app.workflow_id = workflow.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act - Filter by a user that has no workflows + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, + app_model=app, + page=1, + limit=10, + user_id="00000000-0000-0000-0000-000000000000", + ) + + # Assert + assert result_workflows == [] + assert has_more is False + def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session): """ Test creating a new draft workflow through sync operation. 
diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 92dec24c7d..4e89d906f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -1,4 +1,6 @@ -from unittest.mock import patch +from __future__ import annotations + +from unittest.mock import MagicMock, patch import pytest from faker import Faker @@ -534,3 +536,283 @@ class TestWorkspaceService: # Verify database state db_session_with_containers.refresh(tenant) assert tenant.id is not None + + def test_get_tenant_info_should_raise_assertion_when_join_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """TenantAccountJoin must exist; missing join should raise AssertionError.""" + fake = Faker() + account = Account(email=fake.email(), name=fake.name(), interface_language="en-US", status="active") + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=fake.company(), status="normal", plan="basic") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # No TenantAccountJoin created + with patch("services.workspace_service.current_user", account): + with pytest.raises(AssertionError, match="TenantAccountJoin not found"): + WorkspaceService.get_tenant_info(tenant) + + def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """replace_webapp_logo should be None when custom_config_dict does not have the key.""" + import json + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + tenant.custom_config = json.dumps({}) + db_session_with_containers.commit() + + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True + mock_external_service_dependencies["tenant_service"].has_roles.return_value = True + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["custom_config"]["replace_webapp_logo"] is None + + def test_get_tenant_info_should_use_files_url_for_logo_url( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """The logo URL should use dify_config.FILES_URL as the base.""" + import json + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + tenant.custom_config = json.dumps({"replace_webapp_logo": True}) + db_session_with_containers.commit() + + custom_base = "https://cdn.mycompany.io" + mock_external_service_dependencies["dify_config"].FILES_URL = custom_base + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True + mock_external_service_dependencies["tenant_service"].has_roles.return_value = True + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) + + def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( + self, db_session_with_containers: Session, 
mock_external_service_dependencies + ): + """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "SELF_HOSTED" + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert "next_credit_reset_date" not in result + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + def test_get_tenant_info_cloud_credit_reset_date( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """next_credit_reset_date should be present in CLOUD edition.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=None), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["next_credit_reset_date"] == "2025-02-01" + + def test_get_tenant_info_cloud_paid_pool_not_full( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """trial_credits come from paid pool when plan is not sandbox and pool is not full.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=1000, quota_used=200) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=paid_pool), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 1000 + assert result["trial_credits_used"] == 200 + + def test_get_tenant_info_cloud_paid_pool_unlimited( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """quota_limit == -1 means unlimited; service should use paid pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + 
feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=-1, quota_used=999) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, None]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == -1 + assert result["trial_credits_used"] == 999 + + def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_full( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When paid pool is exhausted, switch to trial pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=500, quota_used=500) + trial_pool = MagicMock(quota_limit=100, quota_used=10) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 100 + assert result["trial_credits_used"] == 10 + + def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When paid_pool is None, fall back to trial pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + trial_pool = MagicMock(quota_limit=50, quota_used=5) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 50 + assert result["trial_credits_used"] == 5 + + def test_get_tenant_info_cloud_sandbox_uses_trial_pool( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When plan is SANDBOX, skip paid pool and use trial pool.""" + from enums.cloud_plan import CloudPlan + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + 
feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = CloudPlan.SANDBOX + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=1000, quota_used=0) + trial_pool = MagicMock(quota_limit=200, quota_used=20) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 200 + assert result["trial_credits_used"] == 20 + + def test_get_tenant_info_cloud_both_pools_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When both paid and trial pools are absent, trial_credits should not be set.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, None]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert "trial_credits" not in result + assert "trial_credits_used" not in result diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index ce5c2bd162..ce2fd2eeb1 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -5,6 +5,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -18,9 +21,6 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 4dab895135..7c43bf676b 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ 
b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index d341c5ce99..a16f3ff773 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,6 +3,9 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -17,9 +20,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, ) from extensions.ext_storage import storage -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 9a7507a2f9..96cf9cebf5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import ANY, call, patch import pytest +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index b9f513a6d0..159ab51304 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,12 +24,12 @@ from dataclasses import dataclass from datetime import timedelta import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import delete, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage 
-from graphon.entities import WorkflowExecution
-from graphon.enums import WorkflowExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from models import Account
 from models import WorkflowPause as WorkflowPauseModel
diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
index 8854ef5e04..7539bae685 100644
--- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
+++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
@@ -10,6 +10,7 @@ from typing import Any
 import pytest
 from flask import Flask, Response
 from flask.testing import FlaskClient
+from graphon.enums import BuiltinNodeTypes
 from sqlalchemy.orm import Session
 
 from configs import dify_config
@@ -23,7 +24,6 @@ from core.trigger.debug import event_selectors
 from core.trigger.debug.event_bus import TriggerDebugEventBus
 from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller
 from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
-from graphon.enums import BuiltinNodeTypes
 from libs.datetime_utils import naive_utc_now
 from models.account import Account, Tenant
 from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py
index 2d218dac7e..c52bc02420 100644
--- a/api/tests/unit_tests/controllers/console/app/test_audio.py
+++ b/api/tests/unit_tests/controllers/console/app/test_audio.py
@@ -4,6 +4,7 @@ import io
 from types import SimpleNamespace
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import InternalServerError
 
@@ -20,7 +21,6 @@ from controllers.console.app.error import (
     UnsupportedAudioTypeError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.audio import (
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py
index 341efc05ca..3607636880 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py
@@ -5,12 +5,11 @@ from types import SimpleNamespace
 from unittest.mock import Mock
 
 import pytest
+from graphon.file import File, FileTransferMethod, FileType
 from werkzeug.exceptions import HTTPException, NotFound
 
 from controllers.console.app import workflow as workflow_module
 from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.file.models import File
 
 
 def _unwrap(func):
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
index c4a8148446..e11102acb1 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
@@ -6,14 +6,14 @@ from unittest.mock import Mock
 
 import pytest
 from flask import Flask
-
-from controllers.console import wraps as console_wraps
-from controllers.console.app import workflow_run as workflow_run_module
-from controllers.web.error import NotFoundError
 from graphon.entities.pause_reason import HumanInputRequired
 from graphon.enums import WorkflowExecutionStatus
 from graphon.nodes.human_input.entities import FormInput, UserAction
 from graphon.nodes.human_input.enums import FormInputType
+
+from controllers.console import wraps as console_wraps
+from controllers.console.app import workflow_run as workflow_run_module
+from controllers.web.error import NotFoundError
 from libs import login as login_lib
 from models.account import Account, AccountStatus, TenantAccountRole
 from models.workflow import WorkflowRun
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index 559b5fea09..740da1f1df 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask_restx import marshal
+from graphon.variables.types import SegmentType
 
 from controllers.console.app.workflow_draft_variable import (
     _WORKFLOW_DRAFT_VARIABLE_FIELDS,
@@ -15,7 +16,6 @@ from controllers.console.app.workflow_draft_variable import (
 )
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from factories.variable_factory import build_segment
-from graphon.variables.types import SegmentType
 from libs.datetime_utils import naive_utc_now
 from libs.uuid_utils import uuidv7
 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
@@ -310,8 +310,7 @@ def test_workflow_node_variables_fields():
 
 def test_workflow_file_variable_with_signed_url():
     """Test that File type variables include signed URLs in API responses."""
-    from graphon.file.enums import FileTransferMethod, FileType
-    from graphon.file.models import File
+    from graphon.file import File, FileTransferMethod, FileType
 
     # Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
     test_file = File(
@@ -367,8 +366,7 @@ def test_workflow_file_variable_remote_url():
     """Test that File type variables with REMOTE_URL transfer method return the remote URL."""
-    from graphon.file.enums import FileTransferMethod, FileType
-    from graphon.file.models import File
+    from graphon.file import File, FileTransferMethod, FileType
 
     # Create a File object with REMOTE_URL transfer method
     test_file = File(
diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py
index 5136922e88..9c9f8da87c 100644
--- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py
+++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from werkzeug.exceptions import Forbidden, NotFound
 
 from controllers.console import console_ns
@@ -17,7 +18,6 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import (
     DatasourceUpdateProviderNameApi,
 )
 from core.plugin.impl.oauth import OAuthHandler
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from services.datasource_provider_service import DatasourceProviderService
 from services.plugin.oauth_service import OAuthProxyService
diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py
index 63950736c5..6ef8ccfdbd 100644
--- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py
+++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask import Response
+from graphon.variables.types import SegmentType
 
 from controllers.console import console_ns
 from controllers.console.app.error import DraftWorkflowNotExist
@@ -15,7 +16,6 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor
 )
 from controllers.web.error import InvalidArgumentError, NotFoundError
 from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
-from graphon.variables.types import SegmentType
 from models.account import Account
diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py
index 472d133349..a3c0592d76 100644
--- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py
+++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py
@@ -2,7 +2,7 @@ from datetime import datetime
 from unittest.mock import MagicMock, patch
 
 import pytest
-from werkzeug.exceptions import Forbidden, HTTPException, NotFound
+from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
 
 import services
 from controllers.console import console_ns
@@ -26,6 +26,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
     RagPipelineWorkflowLastRunApi,
 )
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
+from libs.datetime_utils import naive_utc_now
 from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
 from services.errors.llm import InvokeRateLimitError
@@ -372,7 +373,7 @@
 
         workflow = MagicMock(
             id="w1",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
 
         session = MagicMock()
@@ -691,6 +692,57 @@ class TestRagPipelineByIdApi:
             result, status = method(api, pipeline, "w1")
         assert status == 400
 
+    def test_delete_success(self, app):
+        api = RagPipelineByIdApi()
+        method = unwrap(api.delete)
+
+        pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1")
+
+        workflow_service = MagicMock()
+
+        session = MagicMock()
+        session_ctx = MagicMock()
+        session_ctx.__enter__.return_value = session
+        session_ctx.__exit__.return_value = None
+
+        fake_db = MagicMock()
+        fake_db.engine = MagicMock()
+
+        with (
+            app.test_request_context("/", method="DELETE"),
+            patch(
+                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
+                fake_db,
+            ),
+            patch(
+                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
+                return_value=session_ctx,
+            ),
+            patch(
+                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService",
+                return_value=workflow_service,
+            ),
+        ):
+            result = method(api, pipeline, "old-workflow")
+
+        workflow_service.delete_workflow.assert_called_once_with(
+            session=session,
+            workflow_id="old-workflow",
+            tenant_id="t1",
+        )
+        session.commit.assert_called_once()
+        assert result == (None, 204)
+
+    def test_delete_active_workflow_rejected(self, app):
+        api = RagPipelineByIdApi()
+        method = unwrap(api.delete)
+
+        pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1")
+
+        with app.test_request_context("/", method="DELETE"):
+            with pytest.raises(BadRequest, match="currently in use by pipeline"):
+                method(api, pipeline, "active-workflow")
+
 
 class TestRagPipelineWorkflowLastRunApi:
     def test_last_run_success(self, app):
diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py
index 306a772fd1..693b06e95b 100644
--- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py
+++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py
@@ -1,4 +1,3 @@
-from datetime import datetime
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
@@ -25,6 +24,7 @@ from controllers.console.datasets.error import (
 )
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.rag.index_processor.constant.index_type import IndexStructureType
+from libs.datetime_utils import naive_utc_now
 from models.dataset import ChildChunk, DocumentSegment
 from models.model import UploadFile
@@ -54,8 +54,8 @@ def _segment():
         disabled_by=None,
         status="normal",
         created_by="u1",
-        created_at=datetime.utcnow(),
-        updated_at=datetime.utcnow(),
+        created_at=naive_utc_now(),
+        updated_at=naive_utc_now(),
         updated_by="u1",
         indexing_at=None,
         completed_at=None,
diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py
index e4acd91b76..710c9be684 100644
--- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py
+++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
@@ -20,7 +21,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.errors.invoke import InvokeError
 from models.account import Account
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService
diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py
index b4b57022e2..66c9ba48c5 100644
--- a/api/tests/unit_tests/controllers/console/explore/test_audio.py
+++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py
@@ -2,6 +2,7 @@ from io import BytesIO
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import InternalServerError
 
 import controllers.console.explore.audio as audio_module
@@ -19,7 +20,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.errors.audio import (
     AudioTooLargeServiceError,
     NoAudioUploadedServiceError,
diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py
index 145cc9cdd7..2e4ca4f2a4 100644
--- a/api/tests/unit_tests/controllers/console/explore/test_message.py
+++ b/api/tests/unit_tests/controllers/console/explore/test_message.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import InternalServerError, NotFound
 
 import controllers.console.explore.message as module
@@ -21,7 +22,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import (
     FirstMessageNotExistsError,
diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py
index 03eadcdb4e..04beb31389 100644
--- a/api/tests/unit_tests/controllers/console/explore/test_trial.py
+++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import controllers.console.explore.trial as module
@@ -25,7 +26,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.errors.invoke import InvokeError
 from models import Account
 from models.account import TenantStatus
 from models.model import AppMode
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
index b2f949c6e2..9c42ee9529 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
@@ -11,10 +11,9 @@ from unittest.mock import MagicMock
 import pytest
 from flask import Flask
 from flask.views import MethodView
-from werkzeug.exceptions import Forbidden
-
 from graphon.model_runtime.entities.model_entities import ModelType
 from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from werkzeug.exceptions import Forbidden
 
 if not hasattr(builtins, "MethodView"):
     builtins.MethodView = MethodView  # type: ignore[attr-defined]
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
index 168479af1e..fb9eec98cb 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from pydantic_core import ValidationError
 from werkzeug.exceptions import Forbidden
 
@@ -13,7 +14,6 @@ from controllers.console.workspace.model_providers import (
     ModelProviderValidateApi,
     PreferredProviderTypeUpdateApi,
 )
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 
 VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
 INVALID_UUID = "123"
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py
index f0d32f81fb..c829327bc7 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_models.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py
@@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask import Flask
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 
 from controllers.console.workspace.models import (
     DefaultModelApi,
@@ -14,8 +16,6 @@ from controllers.console.workspace.models import (
     ModelProviderModelParameterRuleApi,
     ModelProviderModelValidateApi,
 )
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 
 
 def unwrap(func):
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py
index eb19243225..ce5fd1c466 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py
@@ -90,8 +90,8 @@ class TestPluginListLatestVersionsApi:
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginDebuggingKeyApi:
@@ -120,8 +120,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginListApi:
@@ -202,8 +202,9 @@
             patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0),
             patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock,
         ):
-            with pytest.raises(ValueError):
+            with pytest.raises(ValueError) as exc_info:
                 method(api)
+            assert "File size exceeds the maximum allowed size" in str(exc_info.value)
 
             upload_pkg_mock.assert_not_called()
@@ -365,8 +366,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginUploadFromGithubApi:
@@ -401,8 +402,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
            ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginUploadFromBundleApi:
@@ -449,8 +450,9 @@
             patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0),
             patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock,
         ):
-            with pytest.raises(ValueError):
+            with pytest.raises(ValueError) as exc_info:
                 method(api)
+            assert "File size exceeds the maximum allowed size" in str(exc_info.value)
 
             upload_bundle_mock.assert_not_called()
@@ -495,8 +497,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginInstallFromMarketplaceApi:
@@ -532,8 +534,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginFetchMarketplacePkgApi:
@@ -562,8 +564,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginFetchManifestApi:
@@ -595,8 +597,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginFetchInstallTasksApi:
@@ -625,8 +627,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginFetchInstallTaskApi:
@@ -655,8 +657,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api, "t")
+            result = method(api, "t")
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginDeleteInstallTaskApi:
@@ -685,8 +687,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api, "t")
+            result = method(api, "t")
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginDeleteAllInstallTaskItemsApi:
@@ -717,8 +719,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginDeleteInstallTaskItemApi:
@@ -747,8 +749,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api, "task1", "item1")
+            result = method(api, "task1", "item1")
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginUpgradeFromMarketplaceApi:
@@ -790,8 +792,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginUpgradeFromGithubApi:
@@ -839,8 +841,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
@@ -894,8 +896,8 @@
                 side_effect=PluginDaemonClientSideError("error"),
             ),
         ):
-            with pytest.raises(ValueError):
-                method(api)
+            result = method(api)
+            assert result == ({"code": "plugin_error", "message": "error"}, 400)
 
 
 class TestPluginChangePreferencesApi:
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py
index f5ebe0b534..b2d13dbbdf 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py
@@ -1,4 +1,3 @@
-from datetime import datetime
 from io import BytesIO
 from unittest.mock import MagicMock, patch
 
@@ -26,6 +25,7 @@ from controllers.console.workspace.workspace import (
     WorkspacePermissionApi,
 )
 from enums.cloud_plan import CloudPlan
+from libs.datetime_utils import naive_utc_now
 from models.account import TenantStatus
 
@@ -44,13 +44,13 @@ class TestTenantListApi:
             id="t1",
             name="Tenant 1",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
         tenant2 = MagicMock(
             id="t2",
             name="Tenant 2",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
 
         with (
@@ -97,13 +97,13 @@
             id="t1",
             name="Tenant 1",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
         tenant2 = MagicMock(
             id="t2",
             name="Tenant 2",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
 
         features_t2 = MagicMock()
@@ -152,13 +152,13 @@
             id="t1",
             name="Tenant 1",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
         tenant2 = MagicMock(
             id="t2",
             name="Tenant 2",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
 
         features = MagicMock()
@@ -204,7 +204,7 @@
             id="t1",
             name="Tenant",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
 
         features = MagicMock()
@@ -243,13 +243,13 @@
             id="t1",
             name="Tenant 1",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
         tenant2 = MagicMock(
             id="t2",
             name="Tenant 2",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
 
         with (
@@ -305,7 +305,7 @@ class TestWorkspaceListApi:
         api = WorkspaceListApi()
         method = unwrap(api.get)
 
-        tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow())
+        tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now())
 
         paginate_result = MagicMock(
             items=[tenant],
@@ -331,7 +331,7 @@
             id="t1",
             name="T",
             status="active",
-            created_at=datetime.utcnow(),
+            created_at=naive_utc_now(),
         )
 
         paginate_result = MagicMock(
diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py
index 5862239142..4a5f91cc5d 100644
--- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py
+++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py
@@ -64,18 +64,18 @@
     def test_returns_active_account(self, mock_db):
         mock_account = MagicMock()
         mock_account.status = "active"
-        mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
+        mock_db.session.scalar.return_value = mock_account
 
         result = _get_active_account("user@example.com")
 
         assert result is mock_account
-        mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
+        mock_db.session.scalar.assert_called_once()
 
     @patch("controllers.inner_api.app.dsl.db")
     def test_returns_none_for_inactive_account(self, mock_db):
         mock_account = MagicMock()
         mock_account.status = "banned"
-        mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
+        mock_db.session.scalar.return_value = mock_account
 
         result = _get_active_account("banned@example.com")
 
@@ -83,7 +83,7 @@ class TestGetActiveAccount:
 
     @patch("controllers.inner_api.app.dsl.db")
     def test_returns_none_for_nonexistent_email(self, mock_db):
-        mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
+        mock_db.session.scalar.return_value = None
 
         result = _get_active_account("missing@example.com")
 
@@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport:
     @patch("controllers.inner_api.app.dsl.db")
     def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
         mock_app = MagicMock()
-        mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
+        mock_db.session.get.return_value = mock_app
         mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
 
         unwrapped = inspect.unwrap(api_instance.get)
@@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport:
     @patch("controllers.inner_api.app.dsl.db")
     def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
         mock_app = MagicMock()
-        mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
+        mock_db.session.get.return_value = mock_app
         mock_dsl_cls.export_dsl.return_value = "yaml-data"
 
         unwrapped = inspect.unwrap(api_instance.get)
@@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport:
 
     @patch("controllers.inner_api.app.dsl.db")
     def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
-        mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
+        mock_db.session.get.return_value = None
 
         unwrapped = inspect.unwrap(api_instance.get)
         with app.test_request_context("?include_secret=false"):
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py
index e81e612803..5a8cb4619f 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py
@@ -13,6 +13,7 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import InternalServerError
 
@@ -29,7 +30,6 @@ from controllers.service_api.app.error import (
     UnsupportedAudioTypeError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.audio import (
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py
index 3364c07e62..57681d8f5b 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py
@@ -16,6 +16,7 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import ValidationError
 from werkzeug.exceptions import BadRequest, NotFound
 
@@ -34,7 +35,6 @@ from controllers.service_api.app.error import (
     NotChatAppError,
 )
 from core.errors.error import QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
 from services.app_task_service import AppTaskService
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
index 6543c27037..b1f036c6f3 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
@@ -19,6 +19,7 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
 import pytest
+from graphon.enums import WorkflowExecutionStatus
 from werkzeug.exceptions import BadRequest, NotFound
 
 from controllers.service_api.app.error import NotWorkflowAppError
@@ -35,7 +36,6 @@ from controllers.service_api.app.workflow import (
     WorkflowTaskStopApi,
 )
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
-from graphon.enums import WorkflowExecutionStatus
 from models.model import App, AppMode
 from services.app_generate_service import AppGenerateService
 from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
index eda270258d..4b8e3a738c 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
@@ -1,8 +1,9 @@
 from types import SimpleNamespace
 
-from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
 from graphon.enums import WorkflowExecutionStatus
 
+from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
+
 
 def test_workflow_run_status_field_with_enum() -> None:
     field = WorkflowRunStatusField()
diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py
index a6ca441801..cbfc8fa613 100644
--- a/api/tests/unit_tests/controllers/web/test_audio.py
+++ b/api/tests/unit_tests/controllers/web/test_audio.py
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask import Flask
+from graphon.model_runtime.errors.invoke import InvokeError
 
 from controllers.web.audio import AudioApi, TextApi
 from controllers.web.error import (
@@ -21,7 +22,6 @@ from controllers.web.error import (
     UnsupportedAudioTypeError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 from services.errors.audio import (
     AudioTooLargeServiceError,
     NoAudioUploadedServiceError,
diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py
index 4f8d848637..49039d03fe 100644
--- a/api/tests/unit_tests/controllers/web/test_completion.py
+++ b/api/tests/unit_tests/controllers/web/test_completion.py
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from flask import Flask
+from graphon.model_runtime.errors.invoke import InvokeError
 
 from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
 from controllers.web.error import (
@@ -18,7 +19,6 @@ from controllers.web.error import (
     ProviderQuotaExceededError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeError
 
 
 def _completion_app() -> SimpleNamespace:
diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py
index 683cc0e36f..db4b293b16 100644
--- a/api/tests/unit_tests/core/agent/test_base_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py
@@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool:
 class TestBaseAgentRunnerInit:
     def test_init_sets_stream_tool_call_and_files(self, mocker):
         session = mocker.MagicMock()
-        session.query.return_value.where.return_value.count.return_value = 2
+        session.scalar.return_value = 2
         mocker.patch.object(module.db, "session", session)
 
         mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py
index cde8820e00..bc7aea0ef9 100644
--- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py
@@ -2,11 +2,11 @@ import json
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.model_runtime.entities.llm_entities import LLMUsage
 
 from core.agent.cot_agent_runner import CotAgentRunner
 from core.agent.entities import AgentScratchpadUnit
 from core.agent.errors import AgentMaxIterationError
-from graphon.model_runtime.entities.llm_entities import LLMUsage
 
 
 class DummyRunner(CotAgentRunner):
diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
index ea8cc8aa86..97206019b9 100644
--- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
@@ -1,9 +1,9 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
 
 from core.agent.cot_chat_agent_runner import CotChatAgentRunner
-from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
 from tests.unit_tests.core.agent.conftest import (
     DummyAgentConfig,
     DummyAppConfig,
diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py
index 2f5873d865..defc8b4b64 100644
--- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py
@@ -1,8 +1,6 @@
 import json
 
 import pytest
-
-from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
 from graphon.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -10,6 +8,8 @@ from graphon.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 
+from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
+
 # -----------------------------
 # Fixtures
 # -----------------------------
diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py
index 17ab5babcb..a44a0650eb 100644
--- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py
+++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py
@@ -3,11 +3,6 @@ from typing import Any
 from unittest.mock import MagicMock
 
 import pytest
-
-from core.agent.errors import AgentMaxIterationError
-from core.agent.fc_agent_runner import FunctionCallAgentRunner
-from core.app.apps.base_app_queue_manager import PublishFrom
-from core.app.entities.queue_entities import QueueMessageFileEvent
 from graphon.model_runtime.entities.llm_entities import LLMUsage
 from graphon.model_runtime.entities.message_entities import (
     DocumentPromptMessageContent,
@@ -16,6 +11,11 @@ from graphon.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 
+from core.agent.errors import AgentMaxIterationError
+from core.agent.fc_agent_runner import FunctionCallAgentRunner
+from core.app.apps.base_app_queue_manager import PublishFrom
+from core.app.entities.queue_entities import QueueMessageFileEvent
+
 # ==============================
 # Dummy Helper Classes
 # ==============================
diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py
index 186b4a501d..5ee66da94a 100644
--- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py
+++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py
@@ -2,6 +2,8 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.entities.model_entities import ModelStatus
@@ -10,8 +12,6 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 
 
 class TestModelConfigConverter:
diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py
index d9fe7004ff..e2f3c16335 100644
--- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py
+++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py
@@ -1,9 +1,9 @@
 import pytest
+from graphon.variables.input_entities import VariableEntityType
 
 from core.app.app_config.easy_ui_based_app.variables.manager import (
     BasicVariablesConfigManager,
 )
-from graphon.variables.input_entities import VariableEntityType
 
 
 class TestBasicVariablesConfigManagerConvert:
diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py
index 11fc15c94d..8bde9c1f97 100644
--- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py
+++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py
@@ -1,7 +1,8 @@
-from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from graphon.file.models import FileTransferMethod, FileUploadConfig, ImageConfig
+from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+
 
 def test_convert_with_vision():
     config = {
diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py
index f2bc3076da..000f83cd5a 100644
--- a/api/tests/unit_tests/core/app/app_config/test_entities.py
+++ b/api/tests/unit_tests/core/app/app_config/test_entities.py
@@ -1,10 +1,10 @@
 import pytest
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 
 from core.app.app_config.entities import (
     DatasetRetrieveConfigEntity,
     PromptTemplateEntity,
 )
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 
 
 class TestAppConfigEntities:
diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py
index 8b0ff7b6c1..af5d203f12 100644
--- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py
+++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py
@@ -9,10 +9,17 @@ from pydantic import BaseModel, ValidationError
 
 from constants import UUID_NIL
 from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
-from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model
+from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
+from core.app.apps.advanced_chat.generate_task_pipeline import (
+    ConversationSnapshot,
+    MessageSnapshot,
+    WorkflowSnapshot,
+)
 from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.ops.ops_trace_manager import TraceQueueManager
+from libs.datetime_utils import naive_utc_now
+from models.enums import MessageStatus
 from models.model import AppMode
@@ -363,8 +370,15 @@
             workflow_run_id="run-id",
         )
 
+        workflow = SimpleNamespace(id="wf-1", tenant_id="tenant", features={"feature": True}, features_dict={})
         conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None)
-        message = SimpleNamespace(id="msg-1")
+        message = SimpleNamespace(
+            id="msg-1",
+            query="hello",
+            created_at=naive_utc_now(),
+            status=MessageStatus.NORMAL,
+            answer="",
+        )
         db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock())
         captured: dict[str, object] = {}
         thread_data: dict[str, object] = {}
@@ -394,19 +408,6 @@
                 thread_data["started"] = True
 
         monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread)
-        monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model)
-
-        class _Session:
-            def __init__(self, *args, **kwargs):
-                _ = args, kwargs
-
-            def __enter__(self):
-                return self
-
-            def __exit__(self, exc_type, exc, tb):
-                return False
-
-        monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
         monkeypatch.setattr(
             "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session)
         )
@@ -424,7 +425,7 @@
         pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner")
 
         response = generator._generate(
-            workflow=SimpleNamespace(features={"feature": True}),
+            workflow=workflow,
             user=SimpleNamespace(id="user"),
             invoke_from=InvokeFrom.WEB_APP,
             application_generate_entity=application_generate_entity,
@@ -444,6 +445,9 @@
         db_session.refresh.assert_called_once_with(conversation)
         db_session.close.assert_called_once()
         assert captured["draft_var_saver_factory"] == "draft-factory"
+        assert isinstance(captured["workflow"], WorkflowSnapshot)
+        assert isinstance(captured["conversation"], ConversationSnapshot)
+        assert isinstance(captured["message"], MessageSnapshot)
 
     def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch):
         generator = AdvancedChatAppGenerator()
@@ -464,8 +468,15 @@
             workflow_run_id="run-id",
         )
 
+        workflow = SimpleNamespace(id="wf-2", tenant_id="tenant", features={}, features_dict={})
         conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None)
-        message = SimpleNamespace(id="msg-2")
+        message = SimpleNamespace(
+            id="msg-2",
+            query="hello",
+            created_at=naive_utc_now(),
+            status=MessageStatus.NORMAL,
+            answer="",
+        )
         db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock())
         init_records = MagicMock()
         thread_data: dict[str, object] = {}
@@ -491,19 +502,6 @@
                 thread_data["started"] = True
 
         monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread)
-        monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model)
-
-        class _Session:
-            def __init__(self, *args, **kwargs):
-                _ = args, kwargs
-
-            def __enter__(self):
-                return self
-
-            def __exit__(self, exc_type, exc, tb):
-                return False
-
-        monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
         monkeypatch.setattr(
             "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session)
         )
@@ -519,7 +517,7 @@
         )
 
         response = generator._generate(
-            workflow=SimpleNamespace(features={}),
+            workflow=workflow,
            user=SimpleNamespace(id="user"),
             invoke_from=InvokeFrom.WEB_APP,
             application_generate_entity=application_generate_entity,
@@ -940,10 +938,16 @@
         with pytest.raises(GenerateTaskStoppedError):
             generator._handle_advanced_chat_response(
                 application_generate_entity=application_generate_entity,
-                workflow=SimpleNamespace(),
+                workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}),
                 queue_manager=SimpleNamespace(),
-                conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT),
-                message=SimpleNamespace(id="msg"),
+                conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT),
+                message=MessageSnapshot(
+                    id="msg",
+                    query="hello",
+                    created_at=naive_utc_now(),
+                    status=MessageStatus.NORMAL,
+                    answer="",
+                ),
                 user=SimpleNamespace(),
                 draft_var_saver_factory=lambda **kwargs: None,
                 stream=False,
@@ -981,10 +985,16 @@
         with pytest.raises(ValueError, match="other error"):
             generator._handle_advanced_chat_response(
                 application_generate_entity=application_generate_entity,
-                workflow=SimpleNamespace(),
+                workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}),
                 queue_manager=SimpleNamespace(),
-                conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT),
-                message=SimpleNamespace(id="msg"),
+                conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT),
+                message=MessageSnapshot(
+                    id="msg",
+                    query="hello",
+                    created_at=naive_utc_now(),
+                    status=MessageStatus.NORMAL,
+                    answer="",
+                ),
                 user=SimpleNamespace(),
                 draft_var_saver_factory=lambda **kwargs: None,
                 stream=False,
@@ -992,31 +1002,6 @@
 
         logger_exception.assert_called_once()
 
-    def test_refresh_model_returns_detached_model(self, monkeypatch):
-        source_model = SimpleNamespace(id="source-id")
-        detached_model = SimpleNamespace(id="source-id", detached=True)
-
-        class _Session:
-            def __init__(self, *args, **kwargs):
-                _ = args, kwargs
-
-            def __enter__(self):
-                return self
-
-            def __exit__(self, exc_type, exc, tb):
-                return False
-
-            def get(self, model_type, model_id):
-                _ = model_type
-                return detached_model if model_id == "source-id" else None
-
-        monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
-        monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
-
-        refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
-
-        assert refreshed is detached_model
-
     def test_generate_worker_handles_invoke_auth_error(self, monkeypatch):
         generator = AdvancedChatAppGenerator()
         generator._dialogue_count = 1
diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py
index ef7df5e1da..061719d15a 100644
--- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py
+++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py
@@ -3,12 +3,12 @@
 from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
+from graphon.variables import SegmentType
 from sqlalchemy.orm import Session
 
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from factories import variable_factory
-from graphon.variables import SegmentType
 from models import ConversationVariable, Workflow
 
 MINIMAL_GRAPH = {
diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py
index f2df35d7d0..e9fdeefee4 100644
--- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py
@@ -1,5 +1,7 @@
 from collections.abc import Generator
 
+from graphon.enums import WorkflowNodeExecutionStatus
+
 from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
 from core.app.entities.task_entities import (
     ChatbotAppBlockingResponse,
@@ -10,7 +12,6 @@ from core.app.entities.task_entities import (
     NodeStartStreamResponse,
     PingStreamResponse,
 )
-from graphon.enums import WorkflowNodeExecutionStatus
 
 
 class TestAdvancedChatGenerateResponseConverter:
diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py
index 56919d7f65..a6d8598955 100644
--- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py
+++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py
@@ -6,6 +6,8 @@ from types import SimpleNamespace
 from unittest import mock
 
 import pytest
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import WorkflowExecutionStatus
 
 from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -17,11 +19,9 @@ from core.app.entities.queue_entities import (
     QueueWorkflowSucceededEvent,
 )
 from core.app.entities.task_entities import StreamEvent
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.enums import WorkflowExecutionStatus
 from models.enums import MessageStatus
 from models.execution_extra_content import HumanInputContent
-from models.model import EndUser
+from models.model import AppMode, EndUser
 
 
 def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
@@ -159,8 +159,8 @@
         task_id="task-1",
     )
     queue_manager = SimpleNamespace(graph_runtime_state=None)
-    conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat")
-    message = SimpleNamespace(
+    conversation = pipeline_module.ConversationSnapshot(id="conversation-1", mode=AppMode.ADVANCED_CHAT)
+    message = pipeline_module.MessageSnapshot(
         id="message-1",
         created_at=datetime(2024, 1, 1),
         query="hello",
@@ -170,7 +170,7 @@
     user = EndUser()
     user.id = "user-1"
     user.session_id = "session-1"
-    workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})
+    workflow = pipeline_module.WorkflowSnapshot(id="workflow-1", tenant_id="tenant-1", features_dict={})
 
     pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline(
         application_generate_entity=application_generate_entity,
@@ -184,14 +184,33 @@
         draft_var_saver_factory=SimpleNamespace(),
     )
 
-    pipeline._get_message = mock.Mock(return_value=message)
+    stored_message = SimpleNamespace(
+        id="message-1",
+        answer="before",
+        status=MessageStatus.PAUSED,
+        updated_at=None,
+        provider_response_latency=0,
+        message_tokens=0,
+        message_unit_price=0,
+        message_price_unit=0,
+        answer_tokens=0,
+        answer_unit_price=0,
+        answer_price_unit=0,
+        total_price=0,
+        currency="USD",
+        message_metadata=None,
+        invoke_from=InvokeFrom.WEB_APP,
+        from_account_id=None,
+        from_end_user_id="user-1",
+    )
+    pipeline._get_message = mock.Mock(return_value=stored_message)
     pipeline._recorded_files = []
 
     list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after")))
     pipeline._save_message(session=mock.Mock())
 
-    assert message.answer == "beforeafter"
-    assert message.status == MessageStatus.NORMAL
+    assert stored_message.answer == "beforeafter"
+    assert stored_message.status == MessageStatus.NORMAL
 
 
 def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None:
diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
index c78844d173..82b2e51019 100644
--- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
+++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
@@ -1,13 +1,19 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from datetime import datetime
 from types import SimpleNamespace
 
 import pytest
+from graphon.enums import BuiltinNodeTypes
+from graphon.runtime import GraphRuntimeState, VariablePool
 
 from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
-from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
+from core.app.apps.advanced_chat.generate_task_pipeline import (
+    AdvancedChatAppGenerateTaskPipeline,
+    ConversationSnapshot,
+    MessageSnapshot,
+    WorkflowSnapshot,
+)
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.queue_entities import (
     QueueAdvancedChatMessageEndEvent,
@@ -43,8 +49,7 @@ from core.app.entities.task_entities import (
 )
 from core.base.tts.app_generator_tts_publisher import AudioTrunk
 from core.workflow.system_variables import build_system_variables
-from graphon.enums import BuiltinNodeTypes
-from graphon.runtime import GraphRuntimeState, VariablePool
+from libs.datetime_utils import naive_utc_now
 from models.enums import MessageStatus
 from models.model import AppMode, EndUser
 from tests.workflow_test_utils import build_test_variable_pool
@@ -73,15 +78,15 @@ def _make_pipeline():
         workflow_run_id="run-id",
     )
 
-    message = SimpleNamespace(
+    message = MessageSnapshot(
         id="message-id",
         query="hello",
-        created_at=datetime.utcnow(),
+        created_at=naive_utc_now(),
         status=MessageStatus.NORMAL,
         answer="",
     )
-    conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
-    workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
+    conversation = ConversationSnapshot(id="conv-id", mode=AppMode.ADVANCED_CHAT)
+    workflow = WorkflowSnapshot(id="workflow-id", tenant_id="tenant", features_dict={})
     user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
 
     pipeline = AdvancedChatAppGenerateTaskPipeline(
@@ -257,7 +262,7 @@ class TestAdvancedChatGenerateTaskPipeline:
             node_id="node",
             node_type=BuiltinNodeTypes.LLM,
             node_title="LLM",
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             node_run_index=1,
         )
         iter_next = QueueIterationNextEvent(
@@ -273,7 +278,7 @@
             node_id="node",
             node_type=BuiltinNodeTypes.LLM,
             node_title="LLM",
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             node_run_index=1,
         )
         loop_start = QueueLoopStartEvent(
@@ -281,7 +286,7 @@
             node_id="node",
             node_type=BuiltinNodeTypes.LLM,
             node_title="LLM",
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             node_run_index=1,
         )
         loop_next = QueueLoopNextEvent(
@@ -297,7 +302,7 @@
             node_id="node",
             node_type=BuiltinNodeTypes.LLM,
             node_title="LLM",
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             node_run_index=1,
         )
@@ -360,7 +365,7 @@
             node_execution_id="exec",
             node_id="node",
             node_type=BuiltinNodeTypes.LLM,
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             inputs={},
             outputs={},
             process_data={},
@@ -370,7 +375,7 @@
             node_execution_id="exec",
             node_id="node",
             node_type=BuiltinNodeTypes.LLM,
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             inputs={},
             outputs={},
             process_data={},
@@ -473,7 +478,7 @@
             node_id="node",
             node_type=BuiltinNodeTypes.LLM,
             node_title="title",
-            expiration_time=datetime.utcnow(),
+            expiration_time=naive_utc_now(),
         )
 
         assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
@@ -591,7 +596,7 @@
             node_execution_id="exec",
             node_id="node",
             node_type=BuiltinNodeTypes.LLM,
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             inputs={},
             outputs={},
             process_data={},
diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py
index 80f7f94b1a..7dc4358150 100644
--- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py
+++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py
@@ -1,12 +1,12 @@
 import contextlib
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 
 from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
 from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import InvokeFrom
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 
 
 class DummyAccount:
diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py
index 4567b35480..08250bc3b6 100644
--- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py
+++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py
@@ -1,10 +1,10 @@
 import pytest
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 
 from core.agent.entities import AgentEntity
 from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
 from core.moderation.base import ModerationError
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 
 
 @pytest.fixture
diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py
index 8f3c41701b..68bcffb0e8 100644
--- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py
+++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py
@@ -2,6 +2,7 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 
 from core.app.apps.chat.app_generator import ChatAppGenerator
 from core.app.apps.chat.app_runner import ChatAppRunner
@@ -9,7 +10,6 @@ from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.moderation.base import ModerationError
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from models.model import AppMode
diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py
index f56ca8de99..f255d2c7df 100644
--- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py
+++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py
@@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
 import pytest
+from graphon.file import FileTransferMethod, FileType
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 
 from core.app.apps.base_app_queue_manager import PublishFrom
 from core.app.apps.base_app_runner import AppRunner
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueMessageFileEvent
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from models.enums import CreatorUserRole
diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py
index d6f7a05cdc..4a94a2b4f1 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py
@@ -1,12 +1,11 @@
 from types import SimpleNamespace
 
 import pytest
+from graphon.runtime import GraphRuntimeState, VariablePool
 
 from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
 from core.workflow.system_variables import build_system_variables
 from core.workflow.variable_pool_initializer import add_variables_to_pool
-from graphon.runtime import GraphRuntimeState
-from graphon.runtime.variable_pool import VariablePool
 
 
 def _make_state(workflow_run_id: str | None) -> GraphRuntimeState:
diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
index 3ab63aed25..328cd12f12 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
@@ -1,9 +1,10 @@
 from collections.abc import Mapping, Sequence
 
-from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
 from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
 from graphon.variables.segments import ArrayFileSegment, FileSegment
 
+from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
+
 
 class TestWorkflowResponseConverterFetchFilesFromVariableValue:
     """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method"""
diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py
index e8946281ac..bc11bf4174 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py
@@ -1,12 +1,13 @@
 from datetime import UTC, datetime
 from types import SimpleNamespace
 
+from graphon.entities import WorkflowStartReason
+from graphon.runtime import GraphRuntimeState, VariablePool
+
 from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent
 from core.workflow.system_variables import build_system_variables
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.runtime import GraphRuntimeState, VariablePool
 
 
 def _build_converter():
diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py
index 492e11ee0f..c9e146ff12 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py
@@ -1,10 +1,11 @@
 from types import SimpleNamespace
 
+from graphon.entities import WorkflowStartReason
+from graphon.runtime import GraphRuntimeState, VariablePool
+
 from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.system_variables import build_system_variables
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.runtime import GraphRuntimeState, VariablePool
 
 
 def _build_converter() -> WorkflowResponseConverter:
diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py
index 7ee375d884..0fde7565d2 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py
@@ -10,6 +10,8 @@ from typing import Any
 from unittest.mock import Mock
 
 import pytest
+from graphon.entities import WorkflowStartReason
+from graphon.enums import BuiltinNodeTypes
 
 from core.app.app_config.entities import WorkflowUIBasedAppConfig
 from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
@@ -25,8 +27,6 @@ from core.app.entities.queue_entities import (
     QueueNodeSucceededEvent,
 )
 from core.workflow.system_variables import build_system_variables
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.enums import BuiltinNodeTypes
 from libs.datetime_utils import naive_utc_now
 from models import Account
 from models.model import AppMode
diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py
index aa2085177e..619d66085a 100644
--- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py
+++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py
@@ -2,11 +2,11 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 
 import core.app.apps.completion.app_runner as module
 from core.app.apps.completion.app_runner import CompletionAppRunner
 from core.moderation.base import ModerationError
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 
 
 @pytest.fixture
diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py
index f2e35f9900..96af9fbdee 100644
--- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py
+++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py
@@ -3,13 +3,13 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 
 import core.app.apps.completion.app_generator as module
 from core.app.apps.completion.app_generator import CompletionAppGenerator
 from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import InvokeFrom
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.message import MessageNotExistsError
diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py
index cfe797aa76..6cdcab29ab 100644
--- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py
@@ -1,5 +1,7 @@
 from collections.abc import Generator
 
+from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
+
 from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter
 from core.app.entities.task_entities import (
     AppStreamResponse,
@@ -10,7 +12,6 @@ from core.app.entities.task_entities import (
     WorkflowAppBlockingResponse,
     WorkflowAppStreamResponse,
 )
-from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
 
 
 def test_convert_blocking_full_and_simple_response():
diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py
index 9db83f5531..4fe82efcb3 100644
--- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py
+++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py
@@ -1,4 +1,5 @@
 import pytest
+from graphon.model_runtime.entities.llm_entities import LLMResult
 
 import core.app.apps.pipeline.pipeline_queue_manager as module
 from core.app.apps.base_app_queue_manager import PublishFrom
@@ -13,7 +14,6 @@ from core.app.entities.queue_entities import (
     QueueWorkflowPartialSuccessEvent,
     QueueWorkflowSucceededEvent,
 )
-from graphon.model_runtime.entities.llm_entities import LLMResult
 
 
 def test_publish_sets_stop_listen_and_raises_on_stopped(mocker):
diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py
index fb19d6d761..ab70996f0a 100644
--- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py
+++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py
@@ -22,11 +22,11 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock
 
 import pytest
+from graphon.graph_events import GraphRunFailedEvent
 
 import core.app.apps.pipeline.pipeline_runner as module
 from core.app.apps.pipeline.pipeline_runner import PipelineRunner
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from graphon.graph_events import GraphRunFailedEvent
 
 
 def _build_app_generate_entity() -> SimpleNamespace:
diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py
index b0f8b423e1..6167be3bbd 100644
--- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py
+++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py
@@ -1,7 +1,7 @@
 import pytest
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 
 from core.app.apps.base_app_generator import BaseAppGenerator
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 
 
 def test_validate_inputs_with_zero():
@@ -476,8 +476,9 @@ class TestBaseAppGeneratorExtras:
         assert converted[1] == "event: ping\n\n"
 
     def test_get_draft_var_saver_factory_debugger(self):
-        from core.app.entities.app_invoke_entities import InvokeFrom
         from graphon.enums import BuiltinNodeTypes
+
+        from core.app.entities.app_invoke_entities import InvokeFrom
         from models import Account
 
         base_app_generator = BaseAppGenerator()
diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py
index 17de39ca99..1dee7fdab6 100644
--- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py
+++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py
@@ -4,6 +4,15 @@ from types import SimpleNamespace
 from unittest.mock
import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ -14,15 +23,6 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 3673b7f68e..a126bc85f7 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -4,19 +4,14 @@ from types import ModuleType, SimpleNamespace from typing import Any import graphon.nodes.human_input.entities # noqa: F401 -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause -from graphon.entities.workflow_start_reason import WorkflowStartReason from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.graph_engine import GraphEngine -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.command_channels import InMemoryChannel from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, @@ -30,6 +25,12 @@ from graphon.nodes.base.node import Node from graphon.nodes.end.entities import EndNodeData from graphon.nodes.start.entities import StartNodeData from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py 
b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 58c7bfa4bc..de5bca161c 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -4,23 +4,6 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest - -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.app.entities.queue_entities import ( - QueueAgentLogEvent, - QueueIterationCompletedEvent, - QueueLoopCompletedEvent, - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowPausedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.system_variables import default_system_variables from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -41,6 +24,23 @@ from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.variables import StringVariable +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.system_variables import default_system_variables + class TestWorkflowBasedAppRunner: def test_resolve_user_from(self): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index 38a947986f..aa789d9ff3 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 620a153204..9e30faecf2 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, 
WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index ef0edf4096..8a717e1dcc 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,6 +3,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -11,11 +16,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables -from graphon.entities.pause_reason import HumanInputRequired -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph_events.graph import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index 7dd7ffd727..b768e813bd 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -9,7 +11,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index a0a999cbc5..29df903aa8 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,14 +2,15 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState + from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import 
WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index 601c3989b9..dabd2594b4 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -1,10 +1,11 @@ from __future__ import annotations from contextlib import contextmanager -from datetime import datetime from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -45,8 +46,7 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables, system_variables_to_mapping -from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool +from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser from tests.workflow_test_utils import build_test_variable_pool @@ -192,7 +192,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -245,7 +245,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -303,7 +303,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) iter_next = QueueIterationNextEvent( @@ -319,7 +319,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_start = QueueLoopStartEvent( @@ -327,7 +327,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_next = QueueLoopNextEvent( @@ -343,7 +343,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) filled_event = QueueHumanInputFormFilledEvent( @@ -359,7 +359,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="title", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ) 
         agent_event = QueueAgentLogEvent(
             id="log",
@@ -648,7 +648,7 @@ class TestWorkflowGenerateTaskPipeline:
             node_title="title",
             node_type=BuiltinNodeTypes.LLM,
             node_run_index=1,
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             provider_type="provider",
             provider_id="provider-id",
             error="error",
@@ -660,7 +660,7 @@ class TestWorkflowGenerateTaskPipeline:
             node_title="title",
             node_type=BuiltinNodeTypes.LLM,
             node_run_index=1,
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             provider_type="provider",
             provider_id="provider-id",
         )
@@ -685,7 +685,7 @@ class TestWorkflowGenerateTaskPipeline:
             node_execution_id="exec-id",
             node_id="node",
             node_type=BuiltinNodeTypes.START,
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             inputs={},
             outputs={},
             process_data={},
@@ -836,7 +836,7 @@ class TestWorkflowGenerateTaskPipeline:
             node_id="node-id",
             node_type=BuiltinNodeTypes.START,
             in_loop_id="loop-id",
-            start_at=datetime.utcnow(),
+            start_at=naive_utc_now(),
             process_data={"k": "v"},
             outputs={"out": 1},
         )
diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py
index 7c79780641..014a0cba72 100644
--- a/api/tests/unit_tests/core/app/entities/test_task_entities.py
+++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py
@@ -1,10 +1,11 @@
+from graphon.enums import WorkflowNodeExecutionStatus
+
 from core.app.entities.task_entities import (
     NodeFinishStreamResponse,
     NodeRetryStreamResponse,
     NodeStartStreamResponse,
     StreamEvent,
 )
-from graphon.enums import WorkflowNodeExecutionStatus
 
 
 class TestTaskEntities:
diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py
index 28745a2091..a78c1b428f 100644
--- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py
@@ -1,17 +1,18 @@
 from collections.abc import Sequence
-from datetime import datetime
 from unittest.mock import Mock
 
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
+from graphon.graph_engine.command_channels import CommandChannel
+from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent
+from graphon.node_events import NodeRunResult
+from graphon.runtime import ReadOnlyGraphRuntimeState
+from graphon.variables import StringVariable
+from graphon.variables.segments import Segment, StringSegment
+
 from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
 from core.workflow.system_variables import SystemVariableKey
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
-from graphon.graph_engine.protocols.command_channel import CommandChannel
-from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent
-from graphon.node_events import NodeRunResult
-from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
-from graphon.variables import StringVariable
-from graphon.variables.segments import Segment, StringSegment
+from libs.datetime_utils import naive_utc_now
 
 
 class MockReadOnlyVariablePool:
@@ -48,7 +49,7 @@ def _build_node_run_succeeded_event() -> NodeRunSucceededEvent:
         id="node-exec-id",
         node_id="assigner",
         node_type=BuiltinNodeTypes.LLM,
-        start_at=datetime.utcnow(),
+        start_at=naive_utc_now(),
         node_run_result=NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             outputs={},
diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
index 92a7788f6e..035e64325b 100644
--- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
@@ -4,6 +4,17 @@ from time import time
 from unittest.mock import Mock
 
 import pytest
+from graphon.entities.pause_reason import SchedulingPause
+from graphon.graph_engine.entities.commands import GraphEngineCommand
+from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError
+from graphon.graph_events import (
+    GraphRunFailedEvent,
+    GraphRunPausedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+)
+from graphon.runtime import ReadOnlyVariablePool
+from graphon.variables.segments import Segment
 
 from core.app.app_config.entities import WorkflowUIBasedAppConfig
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
@@ -14,17 +25,6 @@ from core.app.layers.pause_state_persist_layer import (
     _WorkflowGenerateEntityWrapper,
 )
 from core.workflow.system_variables import SystemVariableKey
-from graphon.entities.pause_reason import SchedulingPause
-from graphon.graph_engine.entities.commands import GraphEngineCommand
-from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError
-from graphon.graph_events.graph import (
-    GraphRunFailedEvent,
-    GraphRunPausedEvent,
-    GraphRunStartedEvent,
-    GraphRunSucceededEvent,
-)
-from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
-from graphon.variables.segments import Segment
 from models.model import AppMode
 from repositories.factory import DifyAPIRepositoryFactory
diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py
index 56705f1a7e..95931f4f8b 100644
--- a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py
@@ -1,5 +1,6 @@
+from graphon.graph_events import GraphRunPausedEvent
+
 from core.app.layers.suspend_layer import SuspendLayer
-from graphon.graph_events.graph import GraphRunPausedEvent
 
 
 class TestSuspendLayer:
diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py
index 1ac9a4d8c0..7cf6eb4f31 100644
--- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py
@@ -1,7 +1,8 @@
 from unittest.mock import Mock, patch
 
-from core.app.layers.timeslice_layer import TimeSliceLayer
 from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand
+
+from core.app.layers.timeslice_layer import TimeSliceLayer
 from services.workflow.entities import WorkflowScheduleCFSPlanEntity
 from services.workflow.scheduler import SchedulerCommand
 
diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py
index ecc431936c..aa9285789b 100644
--- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py
@@ -2,10 +2,11 @@ from datetime import UTC, datetime, timedelta
 from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
+from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent
+from graphon.runtime import VariablePool
+
 from core.app.layers.trigger_post_layer import TriggerPostLayer
 from core.workflow.system_variables import build_system_variables
-from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent
-from graphon.runtime import VariablePool
 from models.enums import WorkflowTriggerStatus
 
 
diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py
index c246f7b783..58aa7d7478 100644
--- a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py
+++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py
@@ -2,11 +2,11 @@ from types import SimpleNamespace
 from unittest.mock import Mock
 
 import pytest
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 
 from core.app.entities.queue_entities import QueueErrorEvent
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.errors.error import QuotaExceededError
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from models.enums import MessageStatus
 
 
diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py
index 1c1bf391d3..4aaa10a81a 100644
--- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py
+++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py
@@ -2,6 +2,8 @@ from types import SimpleNamespace
 from unittest.mock import ANY, Mock, patch
 
 import pytest
+from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
+from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
@@ -26,8 +28,6 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
 from core.base.tts import AppGeneratorTTSPublisher
 from core.ops.ops_trace_manager import TraceQueueManager
-from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
-from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
 from models.model import AppMode
 
 
diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py
index ea000f3886..f7e7b7e20e 100644
--- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py
+++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py
@@ -5,6 +5,9 @@ from types import SimpleNamespace
 from unittest.mock import Mock
 
 import pytest
+from graphon.file import FileTransferMethod
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent
 
 from core.app.app_config.entities import (
     AppAdditionalFeatures,
@@ -38,9 +41,6 @@ from core.app.entities.task_entities import (
 )
 from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
 from core.base.tts import AudioTrunk
-from graphon.file.enums import FileTransferMethod
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent
 from models.model import AppMode
 
 
diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py
index abfbcdb941..31b7313066 100644
--- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py
+++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py
@@ -17,11 +17,11 @@ import uuid
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
+from graphon.file import FileTransferMethod, FileType
 from sqlalchemy.orm import Session
 
 from core.app.entities.task_entities import MessageEndStreamResponse
 from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
-from graphon.file.enums import FileTransferMethod, FileType
 from models.model import MessageFile, UploadFile
 
 
diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py
index 21c761c579..29df7eea86 100644
--- a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py
+++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py
@@ -1,9 +1,10 @@
 from types import SimpleNamespace
 from unittest.mock import patch
 
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey
+
 from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
 from core.app.app_config.entities import ModelConfigEntity
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 from models.provider_ids import ModelProviderID
 
 
diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py
index 5c50cb78da..dc2d82ccd6 100644
--- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py
+++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py
@@ -2,14 +2,14 @@ from datetime import UTC, datetime
 from unittest.mock import Mock
 
 import pytest
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType
+from graphon.node_events import NodeRunResult
 
 from core.app.workflow.layers.persistence import (
     PersistenceWorkflowInfo,
     WorkflowPersistenceLayer,
     _NodeRuntimeSnapshot,
 )
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType
-from graphon.node_events import NodeRunResult
 
 
 def _build_layer() -> WorkflowPersistenceLayer:
diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
index cddd03f4b0..7be9d6ac1e 100644
--- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
+++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
@@ -8,13 +8,13 @@ from unittest.mock import MagicMock, patch
 from urllib.parse import parse_qs, urlparse
 
 import pytest
+from graphon.file import File, FileTransferMethod, FileType
 
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.app.file_access import DatabaseFileAccessController, FileAccessScope
 from core.app.workflow import file_runtime
 from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime
 from core.workflow.file_reference import build_file_reference
-from graphon.file import File, FileTransferMethod, FileType
 from models import ToolFile, UploadFile
 
 
diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
index c4bfb23272..8497261d45 100644
--- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py
+++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
@@ -1,10 +1,10 @@
 from types import SimpleNamespace
 
 import pytest
+from graphon.enums import BuiltinNodeTypes
 
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
 from core.workflow.node_factory import DifyNodeFactory
-from graphon.enums import BuiltinNodeTypes
 
 
 class DummyNode:
diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py
index 82552470a9..a47d3db6f5 100644
--- a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py
+++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py
@@ -2,9 +2,10 @@ from __future__ import annotations
 
 from types import SimpleNamespace
 
-from core.app.workflow.layers.observability import ObservabilityLayer
 from graphon.enums import BuiltinNodeTypes
+
+from core.app.workflow.layers.observability import ObservabilityLayer
 
 
 class TestObservabilityLayerExtras:
     def test_init_tracer_enabled_sets_tracer(self, monkeypatch):
diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py
index 9863f34aba..d8a68f6d00 100644
--- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py
+++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py
@@ -4,27 +4,21 @@ from datetime import UTC, datetime
 from types import SimpleNamespace
 
 import pytest
-
-from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
-from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
-from core.workflow.system_variables import SystemVariableKey, build_system_variables
+from graphon.entities import WorkflowNodeExecution
 from graphon.entities.pause_reason import SchedulingPause
-from graphon.entities.workflow_node_execution import WorkflowNodeExecution
 from graphon.enums import (
     BuiltinNodeTypes,
     WorkflowExecutionStatus,
     WorkflowNodeExecutionStatus,
     WorkflowType,
 )
-from graphon.graph_events.graph import (
+from graphon.graph_events import (
     GraphRunAbortedEvent,
     GraphRunFailedEvent,
     GraphRunPartialSucceededEvent,
     GraphRunPausedEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
-)
-from graphon.graph_events.node import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunPauseRequestedEvent,
@@ -35,6 +29,10 @@ from graphon.graph_events.node import (
 from graphon.node_events import NodeRunResult
 from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
 
+from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.workflow.system_variables import SystemVariableKey, build_system_variables
+
 
 class _RepoRecorder:
     def __init__(self) -> None:
diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py
index 7b433ab57b..5ff9774b52 100644
--- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py
+++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py
@@ -301,7 +301,6 @@ class TestAppGeneratorTTSPublisher:
         publisher = AppGeneratorTTSPublisher("tenant", "voice1")
         publisher.executor = MagicMock()
 
-        from core.app.entities.queue_entities import QueueAgentMessageEvent
         from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
         from graphon.model_runtime.entities.message_entities import (
             AssistantPromptMessage,
@@ -309,6 +308,8 @@ class TestAppGeneratorTTSPublisher:
             TextPromptMessageContent,
         )
 
+        from core.app.entities.queue_entities import QueueAgentMessageEvent
+
         chunk = LLMResultChunk(
             model="model",
             delta=LLMResultChunkDelta(
@@ -336,10 +337,11 @@ class TestAppGeneratorTTSPublisher:
         publisher = AppGeneratorTTSPublisher("tenant", "voice1")
         publisher.executor = MagicMock()
 
-        from core.app.entities.queue_entities import QueueAgentMessageEvent
         from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
         from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
 
+        from core.app.entities.queue_entities import QueueAgentMessageEvent
+
         chunk = LLMResultChunk(
             model="model",
             delta=LLMResultChunkDelta(
diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py
index b37c4c57a1..8e5670e9be 100644
--- a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py
+++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py
@@ -114,13 +114,9 @@ class TestOnToolEnd:
         document = mocker.Mock()
         document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
 
-        mock_query = mocker.Mock()
-        mock_db.session.query.return_value = mock_query
-        mock_query.where.return_value = mock_query
-
         handler.on_tool_end([document])
 
-        mock_query.update.assert_called_once()
+        mock_db.session.execute.assert_called_once()
         mock_db.session.commit.assert_called_once()
 
     def test_on_tool_end_non_parent_child_index(self, handler, mocker):
@@ -138,13 +134,9 @@ class TestOnToolEnd:
             "dataset_id": "dataset-1",
         }
 
-        mock_query = mocker.Mock()
-        mock_db.session.query.return_value = mock_query
-        mock_query.where.return_value = mock_query
-
        handler.on_tool_end([document])
 
-        mock_query.update.assert_called_once()
+        mock_db.session.execute.assert_called_once()
         mock_db.session.commit.assert_called_once()
 
     def test_on_tool_end_empty_documents(self, handler):
diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
index af992e4e9f..b0c72ee42f 100644
--- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py
+++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
@@ -2,16 +2,15 @@ import types
 from collections.abc import Generator
 
 import pytest
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.file import File, FileTransferMethod, FileType
+from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
 
 from contexts.wrapper import RecyclableContextVar
 from core.datasource.datasource_manager import DatasourceManager
 from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType
 from core.datasource.errors import DatasourceProviderNotFoundError
 from core.workflow.file_reference import parse_file_reference
-from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from graphon.file import File
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
 
 
 def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]:
diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py
index 0b91d59953..fbaf6d497d 100644
--- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py
+++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py
@@ -1,11 +1,10 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from graphon.file import File, FileTransferMethod, FileType
 
 from core.datasource.entities.datasource_entities import DatasourceMessage
 from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
-from graphon.file import File
-from graphon.file.enums import FileTransferMethod, FileType
 from models.tools import ToolFile
 
 
diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py
index ef8f360dbf..ff9fd0d8f3 100644
--- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py
+++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py
@@ -1,11 +1,12 @@
+from graphon.nodes.human_input.entities import FormInput, UserAction
+from graphon.nodes.human_input.enums import FormInputType
+
 from core.entities.execution_extra_content import (
     ExecutionExtraContentDomainModel,
     HumanInputContent,
     HumanInputFormDefinition,
     HumanInputFormSubmissionData,
 )
-from graphon.nodes.human_input.entities import FormInput, UserAction
-from graphon.nodes.human_input.enums import FormInputType
 from models.execution_extra_content import ExecutionContentType
 
 
diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
index a0b2820157..2acd278a31 100644
--- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py
+++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
@@ -8,6 +8,9 @@ drive provider mapping behavior.
""" import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -16,9 +19,6 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index fe2c226843..8cf0409c4c 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,6 +6,17 @@ from typing import Any from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -24,17 +35,6 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index a159d3ad4d..8685d16283 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -8,7 +9,6 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index 5890009742..f3ef7fccd0 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -38,13 +38,13 @@ class TestObfuscatedToken: class TestEncryptToken: - @patch("models.engine.db.session.query") + 
@patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_successful_encryption(self, mock_encrypt, mock_query): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -52,10 +52,10 @@ class TestEncryptToken: assert result == base64.b64encode(b"encrypted_data").decode() mock_encrypt.assert_called_with("test_token", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") def test_tenant_not_found(self, mock_query): """Test error when tenant doesn't exist""" - mock_query.return_value.where.return_value.first.return_value = None + mock_query.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -119,7 +119,7 @@ class TestGetDecryptDecoding: class TestEncryptDecryptIntegration: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): @@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration: # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration: class TestSecurity: """Critical security tests for encryption system""" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_cross_tenant_isolation(self, mock_encrypt, mock_query): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -181,12 +181,12 @@ class TestSecurity: with pytest.raises(Exception, match="Decryption error"): decrypt_token("tenant-123", tampered) - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_encryption_randomness(self, mock_encrypt, mock_query): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -205,13 +205,13 @@ class TestEdgeCases: # Test empty string (which is a valid str type) assert obfuscated_token("") == "" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + 
mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -219,13 +219,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_empty").decode() mock_encrypt.assert_called_with("", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -242,13 +242,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_special").decode() mock_encrypt.assert_called_with(token, "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index 6ed9ddb476..b45f6fd9a7 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -2,20 +2,6 @@ import json from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import ( - ResponseFormat, - _handle_native_json_schema, - _handle_prompt_based_schema, - _parse_structured_output, - _prepare_schema_for_model, - _set_response_format, - convert_boolean_to_string, - invoke_llm_with_structured_output, - remove_additional_properties, -) -from core.model_manager import ModelInstance from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -31,6 +17,20 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import ( + ResponseFormat, + _handle_native_json_schema, + _handle_prompt_based_schema, + _parse_structured_output, + _prepare_schema_for_model, + _set_response_format, + convert_boolean_to_string, + invoke_llm_with_structured_output, + remove_additional_properties, +) +from core.model_manager import ModelInstance + class TestStructuredOutput: def test_remove_additional_properties(self): diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index b3a5885814..62e714deb6 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -2,12 
+2,12 @@ import json from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: @@ -314,8 +314,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None # Mock __instruction_modify_common call via invoke_llm mock_response = MagicMock() @@ -328,12 +328,12 @@ class TestLLMGenerator: assert result == {"modified": "prompt"} def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: last_run = MagicMock() last_run.query = "q" last_run.answer = "a" last_run.error = "e" - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run + mock_scalar.return_value = last_run mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' @@ -483,8 +483,8 @@ class TestLLMGenerator: def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity): # Testing placeholders replacement via instruction_modify_legacy for convenience - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"ok": true}' @@ -504,8 +504,8 @@ class TestLLMGenerator: assert "current_val" in user_msg_dict["instruction"] def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No braces here" mock_model_instance.invoke_llm.return_value = mock_response @@ -516,8 +516,8 @@ class TestLLMGenerator: assert "Could not find a valid JSON object" in result["error"] def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None 
mock_response = MagicMock() mock_response.message.get_text_content.return_value = "[1, 2, 3]" mock_model_instance.invoke_llm.return_value = mock_response @@ -556,8 +556,8 @@ class TestLLMGenerator: ) def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed") result = LLMGenerator.instruction_modify_legacy( @@ -566,8 +566,8 @@ class TestLLMGenerator: assert "Failed to generate code" in result["error"] def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = Exception("Random error") result = LLMGenerator.instruction_modify_legacy( @@ -576,8 +576,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No JSON here" diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index bfb1fde502..313d18c695 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import jsonschema import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -18,7 +19,6 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index f459250b8e..9a5fb319d7 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -4,8 +4,6 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest - -from core.memory.token_buffer_memory import TokenBufferMemory from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -13,6 +11,8 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.memory.token_buffer_memory import TokenBufferMemory from models.model import AppMode # --------------------------------------------------------------------------- diff --git 
a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 249ecb5006..6a672fdfd5 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -1,7 +1,6 @@ from unittest.mock import Mock import pytest - from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index c2324fdec4..62d631a754 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -5,6 +5,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module @@ -34,8 +36,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index fa885e9320..2d2be12f05 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -1,6 +1,8 @@ import json from unittest.mock import MagicMock +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -24,8 +26,6 @@ from core.ops.aliyun_trace.utils import ( serialize_json_data, ) from core.rag.models.document import Document -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser @@ -45,11 +45,8 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): end_user_data = MagicMock(spec=EndUser) end_user_data.session_id = "session_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = end_user_data - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = end_user_data from core.ops.aliyun_trace.utils import db @@ -63,11 +60,8 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): message_data.from_account_id = "account_id" message_data.from_end_user_id = "end_user_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = None from core.ops.aliyun_trace.utils import db diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index fdf66d4d40..97f7a16327 100644 --- 
a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( @@ -25,7 +26,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -365,9 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index e89359c25b..bfe916f018 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -3,6 +3,7 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -319,9 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index 7ff6f7dcfd..f4c485a9fc 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -9,6 +9,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from graphon.enums import BuiltinNodeTypes # ── Helpers 
────────────────────────────────────────────────────────────────── @@ -330,7 +330,7 @@ class TestTraceDispatcher: class TestWorkflowTrace: def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -343,7 +343,7 @@ class TestWorkflowTrace: span.end.assert_called_once() def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -374,7 +374,7 @@ class TestWorkflowTrace: ), outputs='{"text": "hello world"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node] + mock_db.session.scalars.return_value.all.return_value = [llm_node] workflow_span = MagicMock() node_span = MagicMock() @@ -397,7 +397,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node] + mock_db.session.scalars.return_value.all.return_value = [qc_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -411,7 +411,7 @@ class TestWorkflowTrace: node_type=BuiltinNodeTypes.HTTP_REQUEST, process_data='{"url": "https://api.com"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] + mock_db.session.scalars.return_value.all.return_value = [http_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -434,7 +434,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node] + mock_db.session.scalars.return_value.all.return_value = [kr_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -448,7 +448,7 @@ class TestWorkflowTrace: def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db): failed_node = _make_node(status="failed") - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node] + mock_db.session.scalars.return_value.all.return_value = [failed_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -459,7 +459,7 @@ class TestWorkflowTrace: node_span.add_event.assert_called_once() def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] workflow_span = MagicMock() mock_tracing["start"].return_value = workflow_span mock_tracing["set"].return_value = "token" @@ -473,7 +473,7 @@ class TestWorkflowTrace: def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db): node = _make_node(inputs=None, outputs=None) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = 
[node] + mock_db.session.scalars.return_value.all.return_value = [node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -486,7 +486,7 @@ class TestWorkflowTrace: assert end_call.kwargs["outputs"] == {} def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -501,7 +501,7 @@ class TestWorkflowTrace: def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db): """When query is empty string, it's falsy so no query key added.""" - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -680,12 +680,12 @@ class TestGetMessageUserId: def test_returns_end_user_session_id(self, trace_instance, mock_db): end_user = MagicMock() end_user.session_id = "session-1" - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"}) assert result == "session-1" def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db): - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"}) assert result == "acc-1" @@ -834,7 +834,7 @@ class TestGenerateNameTrace: class TestGetWorkflowNodes: def test_queries_db(self, trace_instance, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", "n2"] + mock_db.session.scalars.return_value.all.return_value = ["n1", "n2"] result = trace_instance._get_workflow_nodes("run-1") assert result == ["n1", "n2"] diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index 6625cb719f..1cb32f2ee0 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( @@ -18,7 +19,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -373,9 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query) + 
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py index 6113e5c6c8..696f859b6f 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -1,6 +1,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import StatusCode from core.ops.entities.trace_entity import ( @@ -25,8 +27,6 @@ from core.ops.tencent_trace.entities.semconv import ( ) from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.rag.models.document import Document -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class TestTencentSpanBuilder: diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index 265652381c..382e5dadc3 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -2,6 +2,8 @@ import logging from unittest.mock import MagicMock, patch import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( @@ -14,8 +16,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.tencent_trace.tencent_trace import TencentDataTrace -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 4b925390d9..6b5cb5b09a 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/tests/unit_tests/core/ops/test_lookup_helpers.py b/api/tests/unit_tests/core/ops/test_lookup_helpers.py new file mode 100644 index 0000000000..86aa68643d --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_lookup_helpers.py @@ -0,0 +1,554 @@ +"""Unit tests for lookup helper functions in core.ops.ops_trace_manager. 
+ +Covers: +- _lookup_app_and_workspace_names +- _lookup_credential_name +- _lookup_llm_credential_info +- TraceTask._get_user_id_from_metadata +""" + +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_db_and_session_patches(scalar_side_effect=None, scalar_return_value=None): + """Return (mock_db, cm, session) ready to patch 'core.ops.ops_trace_manager.db' + and 'core.ops.ops_trace_manager.Session'. + + Provide either scalar_side_effect (list, for multiple calls) or + scalar_return_value (single value). + """ + mock_db = MagicMock() + mock_db.engine = MagicMock() + + session = MagicMock() + if scalar_side_effect is not None: + session.scalar.side_effect = scalar_side_effect + else: + session.scalar.return_value = scalar_return_value + + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=session) + cm.__exit__ = MagicMock(return_value=False) + + return mock_db, cm, session + + +# --------------------------------------------------------------------------- +# _lookup_app_and_workspace_names +# --------------------------------------------------------------------------- + + +class TestLookupAppAndWorkspaceNames: + """Tests for _lookup_app_and_workspace_names(app_id, tenant_id).""" + + def test_both_found(self): + """Returns (app_name, workspace_name) when both records exist.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "MyWorkspace" + + def test_app_only_found(self): + """Returns (app_name, '') when tenant record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "" + + def test_tenant_only_found(self): + """Returns ('', workspace_name) when app record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "MyWorkspace" + + def test_neither_found(self): + """Returns ('', '') when both DB lookups return None.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == 
"" + assert workspace_name == "" + + def test_none_inputs_skips_db(self): + """Returns ('', '') immediately when both IDs are None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, None) + + mock_session_cls.assert_not_called() + assert app_name == "" + assert workspace_name == "" + + def test_app_id_none_only_queries_tenant(self): + """When app_id is None, only the tenant query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyWorkspace") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, "tenant-456") + + assert app_name == "" + assert workspace_name == "OnlyWorkspace" + assert session.scalar.call_count == 1 + + def test_tenant_id_none_only_queries_app(self): + """When tenant_id is None, only the app query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyApp") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", None) + + assert app_name == "OnlyApp" + assert workspace_name == "" + assert session.scalar.call_count == 1 + + +# --------------------------------------------------------------------------- +# _lookup_credential_name +# --------------------------------------------------------------------------- + + +class TestLookupCredentialName: + """Tests for _lookup_credential_name(credential_id, provider_type).""" + + @pytest.mark.parametrize("provider_type", ["builtin", "plugin", "api", "workflow", "mcp"]) + def test_known_provider_types_return_name(self, provider_type): + """Each valid provider_type results in a DB query and returns the credential name.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="CredentialA") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-123", provider_type) + + assert result == "CredentialA" + session.scalar.assert_called_once() + + def test_credential_not_found_returns_empty_string(self): + """Returns '' when DB yields None for the given credential_id.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-999", "api") + + assert result == "" + + def test_invalid_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately for an unrecognised provider_type — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + 
patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", "unknown_type") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_credential_id_returns_empty_string_without_db(self): + """Returns '' immediately when credential_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name(None, "api") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately when provider_type is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", None) + + mock_session_cls.assert_not_called() + assert result == "" + + def test_builtin_and_plugin_map_to_same_model(self): + """Both 'builtin' and 'plugin' provider_types query BuiltinToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import BuiltinToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["builtin"] is BuiltinToolProvider + assert _PROVIDER_TYPE_TO_MODEL["plugin"] is BuiltinToolProvider + + def test_api_maps_to_api_tool_provider(self): + """'api' maps to ApiToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import ApiToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["api"] is ApiToolProvider + + def test_workflow_maps_to_workflow_tool_provider(self): + """'workflow' maps to WorkflowToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import WorkflowToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["workflow"] is WorkflowToolProvider + + def test_mcp_maps_to_mcp_tool_provider(self): + """'mcp' maps to MCPToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import MCPToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["mcp"] is MCPToolProvider + + +# --------------------------------------------------------------------------- +# _lookup_llm_credential_info +# --------------------------------------------------------------------------- + + +class TestLookupLlmCredentialInfo: + """Tests for _lookup_llm_credential_info(tenant_id, provider, model, model_type).""" + + def _provider_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def _model_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def test_model_level_credential_found(self): + """Returns model-level credential_id and name when ProviderModel has a credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id="model-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel, (3) ProviderModelCredential.credential_name + mock_db, cm, _session 
= _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ModelCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "model-cred-id" + assert cred_name == "ModelCredName" + + def test_provider_level_fallback_when_no_model_credential(self): + """Falls back to provider-level credential when ProviderModel has no credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + model_record = self._model_record(credential_id=None) + + # scalar calls: (1) Provider, (2) ProviderModel (no cred), (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ProvCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_provider_level_fallback_when_no_model_record(self): + """Falls back to provider-level credential when no ProviderModel row exists.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel → None, (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, None, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_no_model_arg_uses_provider_level_only(self): + """When model is None, skips ProviderModel query and uses provider credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderCredential.credential_name — no ProviderModel + mock_db, cm, session = _make_db_and_session_patches(scalar_side_effect=[provider_record, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", None) + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + assert session.scalar.call_count == 2 + + def test_provider_not_found_returns_none_and_empty(self): + """Returns (None, '') when Provider record does not exist.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_none_tenant_id_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when tenant_id is None — no DB 
access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info(None, "openai", "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_none_provider_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when provider is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", None, "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_db_error_on_outer_query_returns_none_and_empty(self): + """Returns (None, '') and logs a warning when the outer DB query raises.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, session = _make_db_and_session_patches() + session.scalar.side_effect = Exception("DB connection failed") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_credential_name_lookup_failure_returns_id_with_empty_name(self): + """When credential name sub-query fails, returns cred_id but '' for name.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # Provider found, no model record, then name lookup raises + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, None, Exception("deleted")] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "" + + def test_no_credential_on_provider_or_model_returns_none_id(self): + """Returns (None, '') when neither provider nor model has a credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id=None) + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, model_record]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + +# --------------------------------------------------------------------------- +# TraceTask._get_user_id_from_metadata +# --------------------------------------------------------------------------- + + +class TestGetUserIdFromMetadata: + """Tests for TraceTask._get_user_id_from_metadata(metadata). + + Pure dict logic — no DB access required. 
+ """ + + @pytest.fixture + def get_user_id(self): + """Return the classmethod under test.""" + from core.ops.ops_trace_manager import TraceTask + + return TraceTask._get_user_id_from_metadata + + def test_from_end_user_id_has_highest_priority(self, get_user_id): + """from_end_user_id takes precedence over all other keys.""" + metadata = { + "from_end_user_id": "eu-abc", + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "end_user:eu-abc" + + def test_from_account_id_used_when_no_end_user(self, get_user_id): + """from_account_id is used when from_end_user_id is absent.""" + metadata = { + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_user_id_used_when_no_end_user_or_account(self, get_user_id): + """user_id is used when both higher-priority keys are absent.""" + metadata = {"user_id": "u-123"} + assert get_user_id(metadata) == "user:u-123" + + def test_returns_anonymous_when_all_keys_absent(self, get_user_id): + """Returns 'anonymous' when metadata has none of the expected keys.""" + assert get_user_id({}) == "anonymous" + + def test_empty_string_end_user_id_is_skipped(self, get_user_id): + """Empty string for from_end_user_id is falsy and falls through to next key.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "acc-xyz", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_empty_string_account_id_is_skipped(self, get_user_id): + """Empty string for from_account_id is falsy and falls through to user_id.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "u-123", + } + assert get_user_id(metadata) == "user:u-123" + + def test_empty_string_user_id_falls_through_to_anonymous(self, get_user_id): + """Empty string for user_id is falsy, so 'anonymous' is returned.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "", + } + assert get_user_id(metadata) == "anonymous" + + def test_only_from_end_user_id_present(self, get_user_id): + """Minimal case: only from_end_user_id present.""" + assert get_user_id({"from_end_user_id": "eu-only"}) == "end_user:eu-only" + + def test_irrelevant_keys_do_not_interfere(self, get_user_id): + """Extra metadata keys have no effect on the result.""" + metadata = {"invoke_from": "web", "app_id": "a1"} + assert get_user_id(metadata) == "anonymous" diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index 2d325ccb0e..e47df0121e 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -86,6 +86,7 @@ def make_message_data(**overrides): created_at = datetime(2025, 2, 20, 12, 0, 0) base = { "id": "msg-id", + "app_id": "app-id", "conversation_id": "conv-id", "created_at": created_at, "updated_at": created_at + timedelta(seconds=3), @@ -156,17 +157,19 @@ def make_workflow_run(): ) -def configure_db_query(session, *, message_file=None, workflow_app_log=None): - def _side_effect(model): - query = MagicMock() - query.filter_by.return_value.first.return_value = None - if message_file and model.__name__ == "MessageFile": - query.filter_by.return_value.first.return_value = message_file - if workflow_app_log and model.__name__ == "WorkflowAppLog": - query.filter_by.return_value.first.return_value = workflow_app_log - return query +def configure_db_scalar(session, *, message_file=None, workflow_app_log=None): + """Configure 
session.scalar to return appropriate values for MessageFile/WorkflowAppLog lookups.""" + original_return = session.scalar.return_value - session.query.side_effect = _side_effect + def _side_effect(stmt): + stmt_str = str(stmt) + if "message_file" in stmt_str.lower(): + return message_file + if "workflow_app_log" in stmt_str.lower(): + return workflow_app_log + return original_return + + session.scalar.side_effect = _side_effect class DummySessionContext: @@ -182,6 +185,9 @@ class DummySessionContext: def __exit__(self, exc_type, exc_val, exc_tb): return False + def execute(self, *args, **kwargs): + return self + def scalar(self, *args, **kwargs): if self._index >= len(self._values): return None @@ -189,6 +195,12 @@ class DummySessionContext: self._index += 1 return value + def scalars(self, *args, **kwargs): + return self + + def all(self): + return [] + @pytest.fixture(autouse=True) def patch_provider_map(monkeypatch): @@ -253,7 +265,7 @@ def workflow_repo_fixture(monkeypatch): def trace_task_message(monkeypatch, mock_db): message_data = make_message_data() monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) - configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) + configure_db_scalar(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) return message_data @@ -297,56 +309,53 @@ def test_obfuscated_decrypt_token(encryption_mocks): def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data app = SimpleNamespace(id="app-id", tenant_id="tenant") - mock_db.scalar.return_value = app + mock_db.scalar.side_effect = [trace_config_data, app] decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") assert decrypted["other_value"] == "info" def test_get_decrypted_tracing_config_missing_trace_config(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.scalar.return_value = None assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = None + mock_db.scalar.side_effect = [trace_config_data, None] with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_decrypted_tracing_config_raises_for_none_config(mock_db): trace_config_data = SimpleNamespace(tracing_config=None) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant") + mock_db.scalar.side_effect = [trace_config_data, SimpleNamespace(tenant_id="tenant")] with pytest.raises(ValueError, match="Tracing config cannot be None"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_ops_trace_instance_handles_none_app(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch): app = 
SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({})) assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_success(monkeypatch, mock_db): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr( "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config", classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}), @@ -380,7 +389,7 @@ def test_get_app_config_through_message_id_app_model_config(mock_db): def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="Invalid tracing provider"): OpsTraceManager.update_app_tracing_config("app", True, "bad") with pytest.raises(ValueError, match="App not found"): @@ -389,26 +398,26 @@ def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): def test_update_app_tracing_config_success(mock_db): app = SimpleNamespace(id="app-id", tracing="{}") - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app OpsTraceManager.update_app_tracing_config("app-id", True, "dummy") assert app.tracing is not None mock_db.commit.assert_called_once() def test_get_app_tracing_config_errors_when_missing(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_app_tracing_config("app") def test_get_app_tracing_config_returns_defaults(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None) + mock_db.get.return_value = SimpleNamespace(tracing=None) assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None} def test_get_app_tracing_config_returns_payload(mock_db): payload = {"enabled": True, "tracing_provider": "dummy"} - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload)) + mock_db.get.return_value = SimpleNamespace(tracing=json.dumps(payload)) assert OpsTraceManager.get_app_tracing_config("app-id") == payload @@ -454,7 +463,7 @@ def test_trace_task_message_trace(trace_task_message, mock_db): def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db): DummySessionContext.scalar_values = ["wf-app-log", "message-ref"] - execution = SimpleNamespace(id_="run-id") + execution = SimpleNamespace(id_="run-id", total_tokens=0) task = TraceTask( trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user" ) @@ -491,7 +500,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message): def 
test_trace_task_tool_trace(monkeypatch, mock_db): custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))]) monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) - configure_db_query(mock_db, message_file=FakeMessageFile()) + configure_db_scalar(mock_db, message_file=FakeMessageFile()) task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id") timer = {"start": 1, "end": 5} result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result") diff --git a/api/tests/unit_tests/core/ops/test_trace_queue_manager.py b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py new file mode 100644 index 0000000000..a4903054e0 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py @@ -0,0 +1,192 @@ +"""Unit tests for TraceQueueManager telemetry guard. + +Verifies that TraceQueueManager.add_trace_task() only enqueues tasks when at +least one consumer is active: +- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR +- A third-party trace instance (Langfuse, etc.) is configured. + +When neither is active, tasks are silently dropped to avoid unnecessary work. +""" + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def trace_queue_manager_and_task(monkeypatch): + """Fixture to provide TraceQueueManager and TraceTask with delayed imports.""" + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type): + self.trace_type = trace_type + self.app_id = None + + class StubTraceQueueManager: + def __init__(self, app_id=None): + self.app_id = app_id + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.ops.entities.trace_entity import TraceTaskName + + ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"]) + TraceQueueManager = ops_module.TraceQueueManager + TraceTask = ops_module.TraceTask + + return TraceQueueManager, TraceTask, TraceTaskName + + +class TestTraceQueueManagerTelemetryGuard: + """Test TraceQueueManager's telemetry guard in add_trace_task().""" + + def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task): + """Verify task is NOT enqueued when telemetry disabled and no trace instance. + + This is the core guard: when _enterprise_telemetry_enabled=False AND + trace_instance=None, the task should be silently dropped.
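Before the individual cases: the guard being exercised reduces to a single disjunction. Here is a compact sketch of the behavior the fixture's stub encodes (illustrative only, not the production TraceQueueManager):

```python
import queue

class SketchTraceQueueManager:
    """Enqueue only when some consumer (enterprise telemetry or a third-party
    tracer such as Langfuse) is actually listening; otherwise drop silently."""

    def __init__(self, app_id, telemetry_enabled: bool, trace_instance, task_queue: queue.Queue):
        self.app_id = app_id
        self._enterprise_telemetry_enabled = telemetry_enabled
        self.trace_instance = trace_instance
        self._queue = task_queue

    def add_trace_task(self, trace_task):
        if self._enterprise_telemetry_enabled or self.trace_instance:
            trace_task.app_id = self.app_id  # tag with the owning app before enqueueing
            self._queue.put(trace_task)
        # No else branch: dropped tasks produce no error and no log entry.
```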
+ """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_not_called() + + def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when enterprise telemetry is enabled. + + When _enterprise_telemetry_enabled=True, the task should be enqueued + regardless of trace_instance state. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task): + """Verify task IS enqueued when third-party trace instance is configured. + + When trace_instance is not None (e.g., Langfuse configured), the task + should be enqueued even if enterprise telemetry is disabled. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when both telemetry and trace instance are enabled. + + When both _enterprise_telemetry_enabled=True AND trace_instance is set, + the task should definitely be enqueued. 
+ """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task): + """Verify app_id is set on the task before enqueuing. + + The guard logic sets trace_task.app_id = self.app_id before calling + trace_manager_queue.put(trace_task). This test verifies that behavior. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="expected-app-id") + manager.add_trace_task(trace_task) + + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "expected-app-id" diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8987b6682c..5014f40afc 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from weave.trace_server.trace_server_interface import TraceStatus from core.ops.entities.config_entity import WeaveConfig @@ -22,7 +23,6 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.weave_trace import WeaveDataTrace -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -802,8 +802,8 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.query", - lambda model: MagicMock(where=lambda: MagicMock(first=lambda: None)), + "core.ops.weave_trace.weave_trace.db.session.get", + lambda model, pk: None, ) trace_instance.start_call = MagicMock() @@ -823,7 +823,7 @@ class TestMessageTrace: trace_instance.file_base_url = "http://files.test" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -845,7 +845,7 @@ class TestMessageTrace: end_user.session_id = "session-xyz" mock_db = MagicMock() - 
mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -865,7 +865,7 @@ class TestMessageTrace: def test_message_trace_no_end_user(self, trace_instance, monkeypatch): """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -883,7 +883,7 @@ class TestMessageTrace: def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch): """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -898,7 +898,7 @@ class TestMessageTrace: def test_message_trace_file_list_none(self, trace_instance, monkeypatch): """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index c2778f082b..3feb4159ad 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation: PluginAppBackwardsInvocation._get_user("uid") def test_get_app_returns_app(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain app_obj = MagicMock(id="app") - query_chain.first.return_value = app_obj - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj))) mocker.patch("core.plugin.backwards_invocation.app.db", db) assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj def test_get_app_raises_when_missing(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain - query_chain.first.return_value = None - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): PluginAppBackwardsInvocation._get_app("app", "tenant") def test_get_app_raises_when_query_fails(self, mocker): - db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down")))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py index c24d3ac012..543b278715 100644 --- 
a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -1,9 +1,10 @@ from types import SimpleNamespace from unittest.mock import patch +from graphon.model_runtime.entities.message_entities import UserPromptMessage + from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.entities.request import RequestInvokeSummary -from graphon.model_runtime.entities.message_entities import UserPromptMessage def test_system_model_helpers_forward_user_id(): diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index 68aa130518..f8d0e127b1 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -6,15 +6,15 @@ from types import SimpleNamespace from unittest.mock import Mock, sentinel import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_schema() -> AIModelEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index f1c4c7e700..a812b01c5b 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -4,6 +4,12 @@ from enum import StrEnum import pytest from flask import Response +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) from pydantic import ValidationError from core.plugin.entities.endpoint import EndpointEntityWithInstance @@ -25,12 +31,6 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) class TestEndpointEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index eae9d9459e..3063ca0197 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,6 +17,14 @@ from unittest.mock import MagicMock, patch import httpx import pytest +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate 
import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -26,6 +34,7 @@ from core.plugin.entities.plugin_daemon import ( from core.plugin.impl.base import BasePluginClient from core.plugin.impl.exc import ( PluginDaemonBadRequestError, + PluginDaemonClientSideError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, PluginDaemonUnauthorizedError, @@ -36,14 +45,6 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: @@ -557,7 +558,7 @@ class TestPluginRuntimeErrorHandling: with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert - with pytest.raises(httpx.HTTPStatusError): + with pytest.raises(PluginDaemonInternalServerError): plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) def test_empty_data_response_error(self, plugin_client, mock_config): @@ -1808,8 +1809,8 @@ class TestPluginInstallerAdvanced: mock_response.raise_for_status = raise_for_status with patch("httpx.request", return_value=mock_response, autospec=True): - # Act & Assert - Should raise HTTPStatusError for 404 - with pytest.raises(httpx.HTTPStatusError): + # Act & Assert - Should raise PluginDaemonClientSideError for 404 + with pytest.raises(PluginDaemonClientSideError): installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en") def test_list_plugins_with_pagination(self, installer, mock_config): diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index 4d4313dd84..90730dff5a 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,13 +1,12 @@ from collections.abc import Generator import pytest +from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 395d392127..2b280dd674 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,13 +2,6 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from 
graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -18,6 +11,13 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 803afa54d7..4a54649b28 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,11 +1,5 @@ from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import ( - ModelConfigWithCredentialsEntity, -) -from core.entities.provider_configuration import ProviderModelBundle -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, @@ -13,6 +7,13 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.entities.provider_configuration import ProviderModelBundle +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 5d865d934c..a4b3960b0a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,5 +1,3 @@ -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -9,6 +7,9 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil + def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 9f9ea33695..e35ce2c48a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,9 +2,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.prompt.prompt_transform import 
PromptTransform -from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 0dc74b33df..3f188cfbb4 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,6 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -18,12 +24,6 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - TextPromptMessageContent, - UserPromptMessage, -) from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py index 1f3247590c..006b4e7345 100644 --- a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -1,12 +1,13 @@ from unittest.mock import MagicMock, patch +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError + from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError def _doc(content: str) -> Document: diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index bfa78fe565..6fd44be4d4 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -12,11 +12,11 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 392f0b458b..d7ba944e58 100644 --- 
a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,10 +49,6 @@ from unittest.mock import Mock, patch import numpy as np import pytest -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from graphon.model_runtime.errors.invoke import ( @@ -60,6 +56,10 @@ from graphon.model_runtime.errors.invoke import ( InvokeConnectionError, InvokeRateLimitError, ) +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index c861871f02..cc2873dd3f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -2,14 +2,14 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 059876d410..450e716636 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,6 +53,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -63,7 +64,6 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document -from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 415597f336..2ec7f0498e 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ 
b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,6 +17,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_manager import ModelInstance from core.rag.index_processor.constant.doc_type import DocType @@ -28,7 +29,6 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner -from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index a7e62e7b0a..c11426163e 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -6,6 +6,8 @@ from uuid import uuid4 import pytest from flask import Flask, current_app +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from sqlalchemy import column from core.app.app_config.entities import ( @@ -35,8 +37,6 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index 43c521dcfd..5a2ecb8220 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,8 +1,9 @@ from unittest.mock import Mock -from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from graphon.model_runtime.entities.llm_entities import LLMUsage +from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter + class TestFunctionCallMultiDatasetRouter: def test_invoke_returns_none_when_no_tools(self) -> None: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index c56528cf55..539ac0f849 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -1,12 +1,13 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish -from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.model_runtime.entities.model_entities import ModelType +from 
core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter + class TestReactMultiDatasetRouter: def test_invoke_returns_none_when_no_tools(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index 2735ec512f..e229d5fc1a 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,9 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowType from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 05b4f3a053..7dbf78d0f0 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from graphon.enums import BuiltinNodeTypes + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 18805bac59..0fc82dda53 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,6 +7,11 @@ from datetime import datetime from types import SimpleNamespace import pytest +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -21,11 +26,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -274,7 +274,7 @@ def _make_form_definition() -> str: inputs=[], user_actions=[UserAction(id="submit", title="Submit")], rendered_content="
hello
", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ).model_dump_json() diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 1297a95df1..8ff0e40587 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -9,6 +9,8 @@ from typing import Any from unittest.mock import MagicMock import pytest +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( FormCreateParams, @@ -29,8 +31,6 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index 6cb3c3c6ac..e5c3e85487 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -3,11 +3,12 @@ from unittest.mock import MagicMock from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from graphon.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 6af7b02d4c..5b4d26b780 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -10,6 +10,12 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from sqlalchemy import Engine, create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker @@ -23,12 +29,6 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( _find_first, _replace_or_append_offload, ) -from graphon.entities import WorkflowNodeExecution -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom diff --git 
a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index abdbc72085..84fe522388 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 5af1376a0a..27729e7f06 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/telemetry/test_facade.py b/api/tests/unit_tests/core/telemetry/test_facade.py new file mode 100644 index 0000000000..36e8e1bbb1 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_facade.py @@ -0,0 +1,181 @@ +"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering.""" + +from __future__ import annotations + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent + + +@pytest.fixture +def telemetry_test_setup(monkeypatch): + module_name = "core.ops.ops_trace_manager" + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type, **kwargs): + self.trace_type = trace_type + self.app_id = None + self.kwargs = kwargs + + class StubTraceQueueManager: + def __init__(self, app_id=None, user_id=None): + self.app_id = app_id + self.user_id = user_id + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager 
import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.telemetry import emit + + return emit, ops_stub.trace_manager_queue + + +class TestTelemetryEmit: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_enterprise_trace_creates_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"key": "value"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_emit_community_trace_enqueued(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + + def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + enterprise_only_traces = [ + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TraceTaskName.NODE_EXECUTION_TRACE, + TraceTaskName.PROMPT_GENERATION_TRACE, + ] + + for trace_name in enterprise_only_traces: + mock_queue.reset_mock() + + event = TelemetryEvent( + name=trace_name, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == trace_name + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_passes_name_directly_to_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"extra": "data"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert isinstance(called_task.trace_type, TraceTaskName) + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_with_provided_trace_manager(self, mock_ee, telemetry_test_setup): + 
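# emit() takes an optional trace_manager; when one is supplied, that manager should receive the trace task. +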
emit_fn, mock_queue = telemetry_test_setup + + mock_trace_manager = MagicMock() + mock_trace_manager.add_trace_task = MagicMock() + + event = TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event, trace_manager=mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + called_task = mock_trace_manager.add_trace_task.call_args[0][0] + assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE diff --git a/api/tests/unit_tests/core/telemetry/test_gateway_integration.py b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py new file mode 100644 index 0000000000..a68fce5e7f --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled +from enterprise.telemetry.contracts import TelemetryCase + + +class TestTelemetryCoreExports: + def test_is_enterprise_telemetry_enabled_exported(self) -> None: + from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func + + assert callable(exported_func) + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayIntegrationTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_to_trace_manager( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + 
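+ # NODE_EXECUTION is enterprise-only, so with enterprise telemetry disabled the gateway must drop it before it reaches the trace manager.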
+ mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_routed_when_ee_enabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationMetricRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_metric_case_routes_to_celery_task( + self, + mock_ee_enabled: MagicMock, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_tool_execution_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"} + + emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_moderation_check_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}} + + emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationCEEligibility: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_workflow_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_message_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"message_id": "msg-abc", "conversation_id": "conv-123"} + + emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager) + + 
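# MESSAGE_RUN is CE-eligible, so it still reaches the trace manager even while enterprise telemetry is disabled. +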
mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_draft_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_execution_data": {}} + + emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_prompt_generation_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"operation_type": "generate", "instruction": "test"} + + emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + +class TestIsEnterpriseTelemetryEnabled: + def test_returns_false_when_exporter_import_fails(self) -> None: + with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}): + result = is_enterprise_telemetry_enabled() + assert result is False + + def test_function_is_callable(self) -> None: + assert callable(is_enterprise_telemetry_enabled) diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index f17927f16b..ac65d0c02b 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,7 @@ import json from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig + from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index afea9144c0..f5efb78b61 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock, patch import pytest import redis +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index b19a21d7f4..331166fe63 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,6 +1,15 @@ from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + 
ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -12,15 +21,6 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 7f6a50af99..259cb5fdd0 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -2,12 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index 1ff81f6120..5d744f88c9 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -6,13 +6,13 @@ from typing import Any from unittest.mock import patch import pytest +from graphon.model_runtime.entities.message_entities import UserPromptMessage from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType -from graphon.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 9ac280e31a..ee0ce51eec 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -6,6 +6,8 @@ from datetime import date from types import SimpleNamespace import pytest +from graphon.file import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -27,8 +29,6 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.errors import ToolInvokeError -from graphon.file.enums import FileType -from graphon.model_runtime.entities.model_entities import 
ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index b3442636b7..7fcebde3c5 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -12,9 +12,9 @@ from unittest.mock import MagicMock, Mock, patch import httpx import pytest +from graphon.file import FileTransferMethod from core.tools.tool_file_manager import ToolFileManager -from graphon.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index 857f4aa178..8c0e7e9419 100644 --- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels(): def test_tool_label_manager_update_tool_labels_db(): controller = _api_controller("api-1") with patch("core.tools.tool_label_manager.db") as mock_db: - delete_query = mock_db.session.query.return_value.where.return_value - delete_query.delete.return_value = None ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) - delete_query.delete.assert_called_once() + mock_db.session.execute.assert_called_once() # only one valid unique label should be inserted. assert mock_db.session.add.call_count == 1 mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 844bc01e29..31b68f0b3f 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -220,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks(): with patch.object(ToolManager, "get_builtin_provider", return_value=controller): with patch("core.helper.credential_utils.check_credential_policy_compliance"): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( - builtin_provider - ) + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"api_key": "secret"} cache = Mock() @@ -274,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials( ) refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456) - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"token": "old"} encrypter.encrypt.return_value = {"token": "encrypted"} @@ -698,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch( "core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller ) as mock_from_db: @@ -730,12 +726,10 @@ def 
test_user_get_api_provider_masks_credentials_and_adds_labels(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller): encrypter = Mock() encrypter.decrypt.return_value = {"api_key_value": "secret"} @@ -750,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): def test_get_api_provider_controller_not_found_raises(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"): ToolManager.get_api_provider_controller("tenant-1", "missing") @@ -809,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api(): workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}') api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}') with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider] + mock_db.session.scalar.side_effect = [workflow_provider, api_provider] assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"} assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"} def test_generate_tool_icon_urls_missing_workflow_and_api_use_default(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525" assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525" diff --git a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py index 4ce73272bf..a93624123e 100644 --- a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py +++ b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py @@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources(): ) db_session = Mock() db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high] - db_session.query.return_value.filter_by.return_value.first.return_value = dataset + db_session.get.return_value = dataset tool = SingleDatasetRetrieverTool( tenant_id="tenant-1", @@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources(): ) db_session = Mock() db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1] - db_session.query.return_value.filter_by.return_value.first.side_effect = [ + db_session.get.side_effect = [ SimpleNamespace(id="dataset-2", name="Dataset Two"), SimpleNamespace(id="dataset-1", name="Dataset One"), ] diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index a4a563a4a1..52f262e1cf 100644 --- 
diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py
index a4a563a4a1..52f262e1cf 100644
--- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py
+++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py
@@ -13,8 +13,6 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch

 import pytest
-
-from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils
 from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 from graphon.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -24,6 +22,8 @@ from graphon.model_runtime.errors.invoke import (
     InvokeServerUnavailableError,
 )

+from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils
+

 def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace:
     model_type_instance = Mock()
diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py
index 43f3fbd5c9..0e3a7e623a 100644
--- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py
+++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py
@@ -1,9 +1,9 @@
 import pytest
+from graphon.variables.input_entities import VariableEntity, VariableEntityType

 from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
 from core.tools.errors import WorkflowToolHumanInputNotSupportedError
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
-from graphon.variables.input_entities import VariableEntity, VariableEntityType


 def test_ensure_no_human_input_nodes_passes_for_non_human_input():
diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py
index b147d7fcdb..2607861b59 100644
--- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py
+++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py
@@ -4,6 +4,7 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock, Mock, patch

 import pytest
+from graphon.variables.input_entities import VariableEntity, VariableEntityType

 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import (
@@ -13,7 +14,6 @@ from core.tools.entities.tool_entities import (
     ToolProviderType,
 )
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
-from graphon.variables.input_entities import VariableEntity, VariableEntityType


 def _controller() -> WorkflowToolProviderController:
diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py
index 72a73dd936..c20edd7400 100644
--- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py
+++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py
@@ -11,6 +11,7 @@ from typing import Any
 from unittest.mock import MagicMock, Mock, patch

 import pytest
+from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.tools.__base.tool_runtime import ToolRuntime
@@ -24,7 +25,6 @@ from core.tools.entities.tool_entities import (
 )
 from core.tools.errors import ToolInvokeError
 from core.tools.workflow_as_tool.tool import WorkflowTool
-from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType


 class StubScalars:
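A second recurring pattern in these import hunks is flattening and regrouping: deep module paths such as `graphon.file.enums` and `graphon.file.models` are replaced by imports from the package root, and `graphon` imports now sort with the third-party block ahead of the first-party `core.*` block. A sketch of the re-export layout that makes `from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType` resolve; the exact source modules are assumptions, since `graphon`'s own `__init__.py` is not part of this diff:

```python
# Assumed shape of graphon/file/__init__.py (illustrative only; the real
# package may split these symbols across its submodules differently).
from graphon.file.enums import FileTransferMethod, FileType
from graphon.file.models import FILE_MODEL_IDENTITY, File

__all__ = ["FILE_MODEL_IDENTITY", "File", "FileTransferMethod", "FileType"]
```

With the symbols re-exported at the root, test modules no longer depend on `graphon`'s internal layout, which is what these hunks converge on.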
diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py
index ee7a3d9c96..78622b78b6 100644
--- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py
+++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py
@@ -11,6 +11,7 @@ from datetime import datetime
 from unittest.mock import MagicMock, patch

 import pytest
+from graphon.enums import BuiltinNodeTypes, NodeType

 from core.plugin.entities.request import TriggerInvokeEventResponse
 from core.trigger.constants import (
@@ -26,7 +27,6 @@ from core.trigger.debug.event_selectors import (
     select_trigger_debug_events,
 )
 from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
-from graphon.enums import BuiltinNodeTypes, NodeType
 from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py
index 72052c8c05..7406b88270 100644
--- a/api/tests/unit_tests/core/variables/test_segment.py
+++ b/api/tests/unit_tests/core/variables/test_segment.py
@@ -2,11 +2,6 @@ import dataclasses

 import orjson
 import pytest
-from pydantic import BaseModel
-
-from core.helper import encrypter
-from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
-from core.workflow.variable_pool_initializer import add_variables_to_pool
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.runtime import VariablePool
 from graphon.variables.segment_group import SegmentGroup
@@ -47,6 +42,11 @@ from graphon.variables.variables import (
     StringVariable,
     Variable,
 )
+from pydantic import BaseModel
+
+from core.helper import encrypter
+from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
+from core.workflow.variable_pool_initializer import add_variables_to_pool


 def _build_variable_pool(
diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py
index d4e862220a..37ecd2890b 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type.py
@@ -1,5 +1,4 @@
 import pytest
-
 from graphon.variables.segment_group import SegmentGroup
 from graphon.variables.segments import StringSegment
 from graphon.variables.types import ArrayValidation, SegmentType
diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
index 14f9b2991d..09254e17a3 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
@@ -9,9 +9,7 @@ from dataclasses import dataclass
 from typing import Any

 import pytest
-
-from graphon.file.enums import FileTransferMethod, FileType
-from graphon.file.models import File
+from graphon.file import File, FileTransferMethod, FileType
 from graphon.variables.segment_group import SegmentGroup
 from graphon.variables.segments import (
     ArrayFileSegment,
diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py
index dae5e1ce98..75b01bf42e 100644
--- a/api/tests/unit_tests/core/variables/test_variables.py
+++ b/api/tests/unit_tests/core/variables/test_variables.py
@@ -1,6 +1,4 @@
 import pytest
-from pydantic import ValidationError
-
 from graphon.variables import (
     ArrayFileVariable,
     ArrayVariable,
@@ -12,6 +10,7 @@ from
graphon.variables import ( StringVariable, ) from graphon.variables.variables import VariableBase +from pydantic import ValidationError def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py deleted file mode 100644 index ef5500b72f..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ /dev/null @@ -1,307 +0,0 @@ -import json -from time import time -from unittest.mock import MagicMock, patch - -import pytest - -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from graphon.variables.variables import StringVariable - - -class StubCoordinator: - def __init__(self) -> None: - self.state = "initial" - - def dumps(self) -> str: - return json.dumps({"state": self.state}) - - def loads(self, data: str) -> None: - payload = json.loads(data) - self.state = payload["state"] - - -class TestGraphRuntimeState: - def test_execution_context_defaults_to_empty_context(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - with state.execution_context: - assert state.execution_context is not None - - state.execution_context = None - - with state.execution_context: - assert state.execution_context is not None - - def test_property_getters_and_setters(self): - # FIXME(-LAN-): Mock VariablePool if needed - variable_pool = VariablePool() - start_time = time() - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time) - - # Test variable_pool property (read-only) - assert state.variable_pool == variable_pool - - # Test start_at property - assert state.start_at == start_time - new_time = time() + 100 - state.start_at = new_time - assert state.start_at == new_time - - # Test total_tokens property - assert state.total_tokens == 0 - state.total_tokens = 100 - assert state.total_tokens == 100 - - # Test node_run_steps property - assert state.node_run_steps == 0 - state.node_run_steps = 5 - assert state.node_run_steps == 5 - - def test_outputs_immutability(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test that getting outputs returns a copy - outputs1 = state.outputs - outputs2 = state.outputs - assert outputs1 == outputs2 - assert outputs1 is not outputs2 # Different objects - - # Test that modifying retrieved outputs doesn't affect internal state - outputs = state.outputs - outputs["test"] = "value" - assert "test" not in state.outputs - - # Test set_output method - state.set_output("key1", "value1") - assert state.get_output("key1") == "value1" - - # Test update_outputs method - state.update_outputs({"key2": "value2", "key3": "value3"}) - assert state.get_output("key2") == "value2" - assert state.get_output("key3") == "value3" - - def test_llm_usage_immutability(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test that getting llm_usage returns a copy - usage1 = state.llm_usage - usage2 = state.llm_usage - assert usage1 is not usage2 # Different objects - - def test_type_validation(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test total_tokens validation - with pytest.raises(ValueError): - 
state.total_tokens = -1 - - # Test node_run_steps validation - with pytest.raises(ValueError): - state.node_run_steps = -1 - - def test_helper_methods(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test increment_node_run_steps - initial_steps = state.node_run_steps - state.increment_node_run_steps() - assert state.node_run_steps == initial_steps + 1 - - # Test add_tokens - initial_tokens = state.total_tokens - state.add_tokens(50) - assert state.total_tokens == initial_tokens + 50 - - # Test add_tokens validation - with pytest.raises(ValueError): - state.add_tokens(-1) - - def test_ready_queue_default_instantiation(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - queue = state.ready_queue - - from graphon.graph_engine.ready_queue import InMemoryReadyQueue - - assert isinstance(queue, InMemoryReadyQueue) - - def test_graph_execution_lazy_instantiation(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - execution = state.graph_execution - - from graphon.graph_engine.domain.graph_execution import GraphExecution - - assert isinstance(execution, GraphExecution) - assert execution.workflow_id == "" - assert state.graph_execution is execution - - def test_response_coordinator_configuration(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - with pytest.raises(ValueError): - _ = state.response_coordinator - - mock_graph = MagicMock() - with patch( - "graphon.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True - ) as coordinator_cls: - coordinator_instance = coordinator_cls.return_value - state.configure(graph=mock_graph) - - assert state.response_coordinator is coordinator_instance - coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph) - - # Configure again with same graph should be idempotent - state.configure(graph=mock_graph) - - other_graph = MagicMock() - with pytest.raises(ValueError): - state.attach_graph(other_graph) - - def test_read_only_wrapper_exposes_additional_state(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - state.configure() - - wrapper = ReadOnlyGraphRuntimeStateWrapper(state) - - assert wrapper.ready_queue_size == 0 - assert wrapper.exceptions_count == 0 - - def test_read_only_wrapper_serializes_runtime_state(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - state.total_tokens = 5 - state.set_output("result", {"success": True}) - state.ready_queue.put("node-1") - - wrapper = ReadOnlyGraphRuntimeStateWrapper(state) - - wrapper_snapshot = json.loads(wrapper.dumps()) - state_snapshot = json.loads(state.dumps()) - - assert wrapper_snapshot == state_snapshot - - def test_dumps_and_loads_roundtrip_with_response_coordinator(self): - variable_pool = VariablePool() - variable_pool.add(("node1", "value"), "payload") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - state.total_tokens = 10 - state.node_run_steps = 3 - state.set_output("final", {"result": True}) - usage = LLMUsage.from_metadata( - { - "prompt_tokens": 2, - "completion_tokens": 3, - "total_tokens": 5, - "total_price": "1.23", - "currency": "USD", - "latency": 0.5, - } - ) - state.llm_usage = usage - state.ready_queue.put("node-A") - - graph_execution = state.graph_execution - graph_execution.workflow_id = "wf-123" - graph_execution.exceptions_count = 4 - 
graph_execution.started = True - - mock_graph = MagicMock() - stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub, autospec=True): - state.attach_graph(mock_graph) - - stub.state = "configured" - - snapshot = state.dumps() - - restored = GraphRuntimeState.from_snapshot(snapshot) - - assert restored.total_tokens == 10 - assert restored.node_run_steps == 3 - assert restored.get_output("final") == {"result": True} - assert restored.llm_usage.total_tokens == usage.total_tokens - assert restored.ready_queue.qsize() == 1 - assert restored.ready_queue.get(timeout=0.01) == "node-A" - - restored_segment = restored.variable_pool.get(("node1", "value")) - assert restored_segment is not None - assert restored_segment.value == "payload" - - restored_execution = restored.graph_execution - assert restored_execution.workflow_id == "wf-123" - assert restored_execution.exceptions_count == 4 - assert restored_execution.started is True - - new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): - restored.attach_graph(mock_graph) - - assert new_stub.state == "configured" - - def test_loads_rehydrates_existing_instance(self): - variable_pool = VariablePool() - variable_pool.add(("node", "key"), "value") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - state.total_tokens = 7 - state.node_run_steps = 2 - state.set_output("foo", "bar") - state.ready_queue.put("node-1") - - execution = state.graph_execution - execution.workflow_id = "wf-456" - execution.started = True - - mock_graph = MagicMock() - original_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub, autospec=True): - state.attach_graph(mock_graph) - - original_stub.state = "configured" - snapshot = state.dumps() - - new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): - restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - restored.attach_graph(mock_graph) - restored.loads(snapshot) - - assert restored.total_tokens == 7 - assert restored.node_run_steps == 2 - assert restored.get_output("foo") == "bar" - assert restored.ready_queue.qsize() == 1 - assert restored.ready_queue.get(timeout=0.01) == "node-1" - - restored_segment = restored.variable_pool.get(("node", "key")) - assert restored_segment is not None - assert restored_segment.value == "value" - - restored_execution = restored.graph_execution - assert restored_execution.workflow_id == "wf-456" - assert restored_execution.started is True - - assert new_stub.state == "configured" - - def test_snapshot_restore_preserves_updated_conversation_variable(self): - variable_pool = VariablePool( - conversation_variables=[StringVariable(name="session_name", value="before")], - ) - variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - snapshot = state.dumps() - restored = GraphRuntimeState.from_snapshot(snapshot) - - restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name")) - assert restored_value is not None - assert restored_value.value == "after" diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py deleted file mode 100644 index 
856ec959b7..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Tests for PauseReason discriminated union serialization/deserialization. -""" - -import pytest -from pydantic import BaseModel, ValidationError - -from graphon.entities.pause_reason import ( - HumanInputRequired, - PauseReason, - SchedulingPause, -) - - -class _Holder(BaseModel): - """Helper model that embeds PauseReason for union tests.""" - - reason: PauseReason - - -class TestPauseReasonDiscriminator: - """Test suite for PauseReason union discriminator.""" - - @pytest.mark.parametrize( - ("dict_value", "expected"), - [ - pytest.param( - { - "reason": { - "TYPE": "human_input_required", - "form_id": "form_id", - "form_content": "form_content", - "node_id": "node_id", - "node_title": "node_title", - }, - }, - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - id="HumanInputRequired", - ), - pytest.param( - { - "reason": { - "TYPE": "scheduled_pause", - "message": "Hold on", - } - }, - SchedulingPause(message="Hold on"), - id="SchedulingPause", - ), - ], - ) - def test_model_validate(self, dict_value, expected): - """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" - holder = _Holder.model_validate(dict_value) - - assert type(holder.reason) == type(expected) - assert holder.reason == expected - - @pytest.mark.parametrize( - "reason", - [ - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - SchedulingPause(message="Hold on"), - ], - ids=lambda x: type(x).__name__, - ) - def test_model_construct(self, reason): - holder = _Holder(reason=reason) - assert holder.reason == reason - - def test_model_construct_with_invalid_type(self): - with pytest.raises(ValidationError): - holder = _Holder(reason=object()) # type: ignore - - def test_unknown_type_fails_validation(self): - """Unknown TYPE values should raise a validation error.""" - with pytest.raises(ValidationError): - _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) diff --git a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py deleted file mode 100644 index e8304b9bcd..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_template.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Tests for template module.""" - -from graphon.nodes.base.template import Template, TextSegment, VariableSegment - - -class TestTemplate: - """Test Template class functionality.""" - - def test_from_answer_template_simple(self): - """Test parsing a simple answer template.""" - template_str = "Hello, {{#node1.name#}}!" - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 3 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello, " - assert isinstance(template.segments[1], VariableSegment) - assert template.segments[1].selector == ["node1", "name"] - assert isinstance(template.segments[2], TextSegment) - assert template.segments[2].text == "!" - - def test_from_answer_template_multiple_vars(self): - """Test parsing an answer template with multiple variables.""" - template_str = "Hello {{#node1.name#}}, your age is {{#node2.age#}}." 
- template = Template.from_answer_template(template_str) - - assert len(template.segments) == 5 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello " - assert isinstance(template.segments[1], VariableSegment) - assert template.segments[1].selector == ["node1", "name"] - assert isinstance(template.segments[2], TextSegment) - assert template.segments[2].text == ", your age is " - assert isinstance(template.segments[3], VariableSegment) - assert template.segments[3].selector == ["node2", "age"] - assert isinstance(template.segments[4], TextSegment) - assert template.segments[4].text == "." - - def test_from_answer_template_no_vars(self): - """Test parsing an answer template with no variables.""" - template_str = "Hello, world!" - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 1 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello, world!" - - def test_from_end_outputs_single(self): - """Test creating template from End node outputs with single variable.""" - outputs_config = [{"variable": "text", "value_selector": ["node1", "text"]}] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 1 - assert isinstance(template.segments[0], VariableSegment) - assert template.segments[0].selector == ["node1", "text"] - - def test_from_end_outputs_multiple(self): - """Test creating template from End node outputs with multiple variables.""" - outputs_config = [ - {"variable": "text", "value_selector": ["node1", "text"]}, - {"variable": "result", "value_selector": ["node2", "result"]}, - ] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 3 - assert isinstance(template.segments[0], VariableSegment) - assert template.segments[0].selector == ["node1", "text"] - assert template.segments[0].variable_name == "text" - assert isinstance(template.segments[1], TextSegment) - assert template.segments[1].text == "\n" - assert isinstance(template.segments[2], VariableSegment) - assert template.segments[2].selector == ["node2", "result"] - assert template.segments[2].variable_name == "result" - - def test_from_end_outputs_empty(self): - """Test creating template from empty End node outputs.""" - outputs_config = [] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 0 - - def test_template_str_representation(self): - """Test string representation of template.""" - template_str = "Hello, {{#node1.name#}}!" 
- template = Template.from_answer_template(template_str) - - assert str(template) == template_str diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py deleted file mode 100644 index 7e08751683..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ /dev/null @@ -1,136 +0,0 @@ -from graphon.runtime import VariablePool -from graphon.variables.segments import ( - BooleanSegment, - IntegerSegment, - NoneSegment, - StringSegment, -) - - -class TestVariablePoolGetAndNestedAttribute: - # - # _get_nested_attribute tests - # - def test__get_nested_attribute_existing_key(self): - pool = VariablePool.empty() - obj = {"a": 123} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert segment.value == 123 - - def test__get_nested_attribute_missing_key(self): - pool = VariablePool.empty() - obj = {"a": 123} - segment = pool._get_nested_attribute(obj, "b") - assert segment is None - - def test__get_nested_attribute_non_dict(self): - pool = VariablePool.empty() - obj = ["not", "a", "dict"] - segment = pool._get_nested_attribute(obj, "a") - assert segment is None - - def test__get_nested_attribute_with_none_value(self): - pool = VariablePool.empty() - obj = {"a": None} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert isinstance(segment, NoneSegment) - - def test__get_nested_attribute_with_empty_string(self): - pool = VariablePool.empty() - obj = {"a": ""} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert isinstance(segment, StringSegment) - assert segment.value == "" - - # - # get tests - # - def test_get_simple_variable(self): - pool = VariablePool.empty() - pool.add(("node1", "var1"), "value1") - segment = pool.get(("node1", "var1")) - assert segment is not None - assert segment.value == "value1" - - def test_get_missing_variable(self): - pool = VariablePool.empty() - result = pool.get(("node1", "unknown")) - assert result is None - - def test_get_with_too_short_selector(self): - pool = VariablePool.empty() - result = pool.get(("only_node",)) - assert result is None - - def test_get_nested_object_attribute(self): - pool = VariablePool.empty() - obj_value = {"inner": "hello"} - pool.add(("node1", "obj"), obj_value) - - # simulate selector with nested attr - segment = pool.get(("node1", "obj", "inner")) - assert segment is not None - assert segment.value == "hello" - - def test_get_nested_object_missing_attribute(self): - pool = VariablePool.empty() - obj_value = {"inner": "hello"} - pool.add(("node1", "obj"), obj_value) - - result = pool.get(("node1", "obj", "not_exist")) - assert result is None - - def test_get_nested_object_attribute_with_falsy_values(self): - pool = VariablePool.empty() - obj_value = { - "inner_none": None, - "inner_empty": "", - "inner_zero": 0, - "inner_false": False, - } - pool.add(("node1", "obj"), obj_value) - - segment_none = pool.get(("node1", "obj", "inner_none")) - assert segment_none is not None - assert isinstance(segment_none, NoneSegment) - - segment_empty = pool.get(("node1", "obj", "inner_empty")) - assert segment_empty is not None - assert isinstance(segment_empty, StringSegment) - assert segment_empty.value == "" - - segment_zero = pool.get(("node1", "obj", "inner_zero")) - assert segment_zero is not None - assert isinstance(segment_zero, IntegerSegment) - assert segment_zero.value == 0 - - segment_false = pool.get(("node1", "obj", 
"inner_false")) - assert segment_false is not None - assert isinstance(segment_false, BooleanSegment) - assert segment_false.value is False - - -class TestVariablePoolGetNotModifyVariableDictionary: - _NODE_ID = "start" - _VAR_NAME = "name" - - def test_convert_to_template_should_not_introduce_extra_keys(self): - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], 0) - pool.convert_template("The start.name is {{#start.name#}}") - assert "The start" not in pool.variable_dictionary - - def test_get_should_not_modify_variable_dictionary(self): - pool = VariablePool.empty() - pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 0 - assert "start" not in pool.variable_dictionary - - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], "Joe") - pool.get([self._NODE_ID, "count"]) - start_subdict = pool.variable_dictionary[self._NODE_ID] - assert "count" not in start_subdict diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py deleted file mode 100644 index 5e697f22f3..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality. -""" - -from dataclasses import dataclass -from datetime import datetime -from typing import Any - -import pytest - -from graphon.entities.workflow_node_execution import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes - - -class TestWorkflowNodeExecutionProcessDataTruncation: - """Test process_data truncation functionality in WorkflowNodeExecution domain model.""" - - def create_workflow_node_execution( - self, - process_data: dict[str, Any] | None = None, - ) -> WorkflowNodeExecution: - """Create a WorkflowNodeExecution instance for testing.""" - return WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=process_data, - created_at=datetime.now(), - ) - - def test_initial_process_data_truncated_state(self): - """Test that process_data_truncated returns False initially.""" - execution = self.create_workflow_node_execution() - - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_set_and_get_truncated_process_data(self): - """Test setting and getting truncated process_data.""" - execution = self.create_workflow_node_execution() - test_truncated_data = {"truncated": True, "key": "value"} - - execution.set_truncated_process_data(test_truncated_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_truncated_data - - def test_set_truncated_process_data_to_none(self): - """Test setting truncated process_data to None.""" - execution = self.create_workflow_node_execution() - - # First set some data - execution.set_truncated_process_data({"key": "value"}) - assert execution.process_data_truncated is True - - # Then set to None - execution.set_truncated_process_data(None) - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_get_response_process_data_with_no_truncation(self): - """Test get_response_process_data when no truncation is set.""" - original_data = {"original": True, "data": "value"} 
- execution = self.create_workflow_node_execution(process_data=original_data) - - response_data = execution.get_response_process_data() - - assert response_data == original_data - assert execution.process_data_truncated is False - - def test_get_response_process_data_with_truncation(self): - """Test get_response_process_data when truncation is set.""" - original_data = {"original": True, "large_data": "x" * 10000} - truncated_data = {"original": True, "large_data": "[TRUNCATED]"} - - execution = self.create_workflow_node_execution(process_data=original_data) - execution.set_truncated_process_data(truncated_data) - - response_data = execution.get_response_process_data() - - # Should return truncated data, not original - assert response_data == truncated_data - assert response_data != original_data - assert execution.process_data_truncated is True - - def test_get_response_process_data_with_none_process_data(self): - """Test get_response_process_data when process_data is None.""" - execution = self.create_workflow_node_execution(process_data=None) - - response_data = execution.get_response_process_data() - - assert response_data is None - assert execution.process_data_truncated is False - - def test_consistency_with_inputs_outputs_pattern(self): - """Test that process_data truncation follows the same pattern as inputs/outputs.""" - execution = self.create_workflow_node_execution() - - # Test that all truncation methods exist and behave consistently - test_data = {"test": "data"} - - # Test inputs truncation - execution.set_truncated_inputs(test_data) - assert execution.inputs_truncated is True - assert execution.get_truncated_inputs() == test_data - - # Test outputs truncation - execution.set_truncated_outputs(test_data) - assert execution.outputs_truncated is True - assert execution.get_truncated_outputs() == test_data - - # Test process_data truncation - execution.set_truncated_process_data(test_data) - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - - @pytest.mark.parametrize( - "test_data", - [ - {"simple": "value"}, - {"nested": {"key": "value"}}, - {"list": [1, 2, 3]}, - {"mixed": {"string": "value", "number": 42, "list": [1, 2]}}, - {}, # empty dict - ], - ) - def test_truncated_process_data_with_various_data_types(self, test_data): - """Test that truncated process_data works with various data types.""" - execution = self.create_workflow_node_execution() - - execution.set_truncated_process_data(test_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - assert execution.get_response_process_data() == test_data - - -@dataclass -class ProcessDataScenario: - """Test scenario data for process_data functionality.""" - - name: str - original_data: dict[str, Any] | None - truncated_data: dict[str, Any] | None - expected_truncated_flag: bool - expected_response_data: dict[str, Any] | None - - -class TestWorkflowNodeExecutionProcessDataScenarios: - """Test various scenarios for process_data handling.""" - - def get_process_data_scenarios(self) -> list[ProcessDataScenario]: - """Create test scenarios for process_data functionality.""" - return [ - ProcessDataScenario( - name="no_process_data", - original_data=None, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data=None, - ), - ProcessDataScenario( - name="process_data_without_truncation", - original_data={"small": "data"}, - truncated_data=None, - expected_truncated_flag=False, - 
expected_response_data={"small": "data"}, - ), - ProcessDataScenario( - name="process_data_with_truncation", - original_data={"large": "x" * 10000, "metadata": "info"}, - truncated_data={"large": "[TRUNCATED]", "metadata": "info"}, - expected_truncated_flag=True, - expected_response_data={"large": "[TRUNCATED]", "metadata": "info"}, - ), - ProcessDataScenario( - name="empty_process_data", - original_data={}, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data={}, - ), - ProcessDataScenario( - name="complex_nested_data_with_truncation", - original_data={ - "config": {"setting": "value"}, - "logs": ["log1", "log2"] * 1000, # Large list - "status": "running", - }, - truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"}, - expected_truncated_flag=True, - expected_response_data={ - "config": {"setting": "value"}, - "logs": "[TRUNCATED: 2000 items]", - "status": "running", - }, - ), - ] - - @pytest.mark.parametrize( - "scenario", - get_process_data_scenarios(None), - ids=[scenario.name for scenario in get_process_data_scenarios(None)], - ) - def test_process_data_scenarios(self, scenario: ProcessDataScenario): - """Test various process_data scenarios.""" - execution = WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=scenario.original_data, - created_at=datetime.now(), - ) - - if scenario.truncated_data is not None: - execution.set_truncated_process_data(scenario.truncated_data) - - assert execution.process_data_truncated == scenario.expected_truncated_flag - assert execution.get_response_process_data() == scenario.expected_response_data diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py deleted file mode 100644 index b138a7dfdc..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ /dev/null @@ -1,281 +0,0 @@ -"""Unit tests for Graph class methods.""" - -from unittest.mock import Mock - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from graphon.graph.edge import Edge -from graphon.graph.graph import Graph -from graphon.nodes.base.node import Node - - -def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: - """Create a mock node for testing.""" - node = Mock(spec=Node) - node.id = node_id - node.execution_type = execution_type - node.state = state - node.node_type = BuiltinNodeTypes.START - return node - - -class TestMarkInactiveRootBranches: - """Test cases for _mark_inactive_root_branches method.""" - - def test_single_root_no_marking(self): - """Test that single root graph doesn't mark anything as skipped.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - } - - in_edges = {"child1": ["edge1"]} - out_edges = {"root1": ["edge1"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["child1"].state == NodeState.UNKNOWN - assert edges["edge1"].state == NodeState.UNKNOWN - - def test_multiple_roots_mark_inactive(self): - """Test marking inactive root branches with multiple root nodes.""" - nodes = { - 
"root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "root2": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - - def test_shared_downstream_node(self): - """Test that shared downstream nodes are not skipped if at least one path is active.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"), - "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "shared": ["edge3", "edge4"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "child1": ["edge3"], - "child2": ["edge4"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - - def test_deep_branch_marking(self): - """Test marking deep branches with multiple levels.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE), - "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE), - "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE), - "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE), - "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"), - "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"), - "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"), - "edge5": 
Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"), - } - - in_edges = { - "level1_a": ["edge1"], - "level1_b": ["edge2"], - "level2_a": ["edge3"], - "level2_b": ["edge4"], - "level3": ["edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "level1_a": ["edge3"], - "level1_b": ["edge4"], - "level2_b": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["level1_a"].state == NodeState.UNKNOWN - assert nodes["level1_b"].state == NodeState.SKIPPED - assert nodes["level2_a"].state == NodeState.UNKNOWN - assert nodes["level2_b"].state == NodeState.SKIPPED - assert nodes["level3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - assert edges["edge5"].state == NodeState.SKIPPED - - def test_non_root_execution_type(self): - """Test that nodes with non-ROOT execution type are not treated as root nodes.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "non_root": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.UNKNOWN - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.UNKNOWN - - def test_empty_graph(self): - """Test handling of empty graph structures.""" - nodes = {} - edges = {} - in_edges = {} - out_edges = {} - - # Should not raise any errors - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent") - - def test_three_roots_mark_two_inactive(self): - """Test with three root nodes where two should be marked inactive.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "child3": ["edge3"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2") - - assert 
nodes["root1"].state == NodeState.SKIPPED - assert nodes["root2"].state == NodeState.UNKNOWN # Active root - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.SKIPPED - assert nodes["child2"].state == NodeState.UNKNOWN - assert nodes["child3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.SKIPPED - assert edges["edge2"].state == NodeState.UNKNOWN - assert edges["edge3"].state == NodeState.SKIPPED - - def test_convergent_paths(self): - """Test convergent paths where multiple inactive branches lead to same node.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE), - "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE), - "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"), - "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"), - "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"), - } - - in_edges = { - "mid1": ["edge1"], - "mid2": ["edge2"], - "convergent": ["edge3", "edge4", "edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - "mid1": ["edge4"], - "mid2": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["mid1"].state == NodeState.UNKNOWN - assert nodes["mid2"].state == NodeState.SKIPPED - assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1 - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.SKIPPED - assert edges["edge4"].state == NodeState.UNKNOWN - assert edges["edge5"].state == NodeState.SKIPPED diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py deleted file mode 100644 index f3eaa1d686..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.graph import Graph -from graphon.nodes.base.node import Node - - -def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: - node = MagicMock(spec=Node) - node.id = node_id - node.node_type = node_type - node.execution_type = None # attribute not used in builder path - return node - - -def test_graph_builder_creates_linear_graph(): - builder = Graph.new() - root = _make_node("root", BuiltinNodeTypes.START) - mid = _make_node("mid", BuiltinNodeTypes.LLM) - end = _make_node("end", BuiltinNodeTypes.END) - - graph = builder.add_root(root).add_node(mid).add_node(end).build() - - assert graph.root_node is root - assert graph.nodes == {"root": root, "mid": mid, "end": end} - assert len(graph.edges) == 2 - first_edge = 
next(iter(graph.edges.values())) - assert first_edge.tail == "root" - assert first_edge.head == "mid" - assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"] - - -def test_graph_builder_supports_custom_predecessor(): - builder = Graph.new() - root = _make_node("root") - branch = _make_node("branch") - other = _make_node("other") - - graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build() - - outgoing_root = graph.out_edges["root"] - assert len(outgoing_root) == 2 - edge_targets = {graph.edges[eid].head for eid in outgoing_root} - assert edge_targets == {"branch", "other"} - - -def test_graph_builder_validates_usage(): - builder = Graph.new() - node = _make_node("node") - - with pytest.raises(ValueError, match="Root node"): - builder.add_node(node) - - builder.add_root(node) - duplicate = _make_node("node") - with pytest.raises(ValueError, match="Duplicate"): - builder.add_node(duplicate) diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py deleted file mode 100644 index 3620a20e56..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import default_system_variables -from graphon.graph import Graph -from graphon.graph.validation import GraphValidationError -from graphon.nodes import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.workflow_test_utils import build_test_graph_init_params - - -def _build_iteration_graph(node_id: str) -> dict[str, Any]: - return { - "nodes": [ - { - "id": node_id, - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": [node_id, "output"], - }, - } - ], - "edges": [], - } - - -def _build_loop_graph(node_id: str) -> dict[str, Any]: - return { - "nodes": [ - { - "id": node_id, - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - "loop_variables": [], - "outputs": {}, - }, - } - ], - "edges": [], - } - - -def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=default_system_variables(), - user_inputs={}, - environment_variables=[], - ), - start_at=0.0, - ) - return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - -def test_iteration_root_requires_skip_validation(): - node_id = "iteration-node" - graph_config = _build_iteration_graph(node_id) - node_factory = _make_factory(graph_config) - - with pytest.raises(GraphValidationError): - Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - ) - - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - skip_validation=True, - ) - - assert graph.root_node.id == node_id - assert graph.root_node.node_type == BuiltinNodeTypes.ITERATION - - -def 
test_loop_root_requires_skip_validation(): - node_id = "loop-node" - graph_config = _build_loop_graph(node_id) - node_factory = _make_factory(graph_config) - - with pytest.raises(GraphValidationError): - Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - ) - - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - skip_validation=True, - ) - - assert graph.root_node.id == node_id - assert graph.root_node.node_type == BuiltinNodeTypes.LOOP diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py deleted file mode 100644 index bfd0b48392..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -import time -from collections.abc import Mapping -from dataclasses import dataclass - -import pytest - -from core.workflow.system_variables import build_system_variables -from graphon.entities import GraphInitParams -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType -from graphon.graph import Graph -from graphon.graph.validation import GraphValidationError -from graphon.nodes.base.node import Node -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.workflow_test_utils import build_test_graph_init_params - - -class _TestNodeData(BaseNodeData): - type: NodeType | None = None - execution_type: NodeExecutionType | str | None = None - - -class _TestNode(Node[_TestNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.EXECUTABLE - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - *, - id: str, - config: Mapping[str, object], - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - node_type_value = self.data.get("type") - if isinstance(node_type_value, str): - self.node_type = node_type_value - - def _run(self): - raise NotImplementedError - - def post_init(self) -> None: - super().post_init() - self._maybe_override_execution_type() - self.data = dict(self.node_data.model_dump()) - - def _maybe_override_execution_type(self) -> None: - execution_type_value = self.node_data.execution_type - if execution_type_value is None: - return - if isinstance(execution_type_value, NodeExecutionType): - self.execution_type = execution_type_value - else: - self.execution_type = NodeExecutionType(execution_type_value) - - -@dataclass(slots=True) -class _SimpleNodeFactory: - graph_init_params: GraphInitParams - graph_runtime_state: GraphRuntimeState - - def create_node(self, node_config: Mapping[str, object]) -> _TestNode: - node_id = str(node_config["id"]) - node = _TestNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return node - - -@pytest.fixture -def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: - graph_config: dict[str, object] = {"edges": [], "nodes": []} - init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - 
) - variable_pool = VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}) - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) - return factory, graph_config - - -def test_graph_initialization_runs_default_validators( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -): - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, - ] - graph_config["edges"] = [ - {"source": "start", "target": "answer", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.root_node.id == "start" - assert "answer" in graph.nodes - - -def test_graph_validation_fails_for_unknown_edge_targets( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "missing", "sourceHandle": "success"}, - ] - - with pytest.raises(GraphValidationError) as exc: - Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) - - -def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - { - "id": "branch", - "data": { - "type": BuiltinNodeTypes.IF_ELSE, - "title": "Branch", - "error_strategy": ErrorStrategy.FAIL_BRANCH, - }, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "branch", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH - - -def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, - { - "id": "note", - "type": "custom-note", - "data": { - "type": "", - "title": "", - "desc": "", - "text": "{}", - "theme": "blue", - }, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "answer", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.root_node.id == "start" - assert "answer" in graph.nodes - assert "note" not in graph.nodes - - -def test_graph_init_fails_for_unknown_root_node_id( - graph_init_dependencies: 
tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - ] - graph_config["edges"] = [] - - with pytest.raises(ValueError, match="Root node id missing not found in the graph"): - Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="missing") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index 960fef7d43..dd419f0810 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -1,441 +1,30 @@ -# Graph Engine Testing Framework +# Workflow Graph Engine Smoke Tests -## Overview +This directory now keeps only a small Dify-owned smoke layer around the external +`graphon` package. -This directory contains a comprehensive testing framework for the Graph Engine, including: +Retained coverage focuses on: -1. **TableTestRunner** - Advanced table-driven test framework for workflow testing -1. **Auto-Mock System** - Powerful mocking framework for testing without external dependencies +1. Dify workflow layers: + - `layers/test_llm_quota.py` + - `layers/test_observability.py` +2. Human-input resume integration: + - `test_parallel_human_input_join_resume.py` +3. One mocked tool/chatflow smoke path: + - `test_tool_in_chatflow.py` -## TableTestRunner Framework +The helper modules below remain only because the retained smoke tests use them: -The TableTestRunner (`test_table_runner.py`) provides a robust table-driven testing framework for GraphEngine workflows. +1. `test_mock_config.py` +2. `test_mock_factory.py` +3. `test_mock_nodes.py` +4. `test_table_runner.py` -### Features - -- **Table-driven testing** - Define test cases as structured data -- **Parallel test execution** - Run tests concurrently for faster execution -- **Property-based testing** - Integration with Hypothesis for fuzzing -- **Event sequence validation** - Verify correct event ordering -- **Mock configuration** - Seamless integration with the auto-mock system -- **Performance metrics** - Track execution times and bottlenecks -- **Detailed error reporting** - Comprehensive failure diagnostics - -### Basic Usage - -```python -from test_table_runner import TableTestRunner, WorkflowTestCase - -# Create test runner -runner = TableTestRunner() - -# Define test case -test_case = WorkflowTestCase( - fixture_path="simple_workflow", - inputs={"query": "Hello"}, - expected_outputs={"result": "World"}, - description="Basic workflow test", -) - -# Run single test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Advanced Features - -#### Parallel Execution - -```python -runner = TableTestRunner(max_workers=8) - -test_cases = [ - WorkflowTestCase(...), - WorkflowTestCase(...), - # ... 
more test cases -] - -# Run tests in parallel -suite_result = runner.run_table_tests( - test_cases, - parallel=True, - fail_fast=False -) - -print(f"Success rate: {suite_result.success_rate:.1f}%") -``` - -#### Event Sequence Validation - -```python -from graphon.graph_events import ( - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, -) - -test_case = WorkflowTestCase( - fixture_path="workflow", - inputs={}, - expected_outputs={}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] -) -``` - -### Test Suite Reports - -```python -# Run test suite -suite_result = runner.run_table_tests(test_cases) - -# Generate detailed report -report = runner.generate_report(suite_result) -print(report) - -# Access specific results -failed_results = suite_result.get_failed_results() -for result in failed_results: - print(f"Failed: {result.test_case.description}") - print(f" Error: {result.error}") -``` - -### Performance Testing - -```python -# Enable logging for performance insights -runner = TableTestRunner( - enable_logging=True, - log_level="DEBUG" -) - -# Run tests and analyze performance -suite_result = runner.run_table_tests(test_cases) - -# Get slowest tests -sorted_results = sorted( - suite_result.results, - key=lambda r: r.execution_time, - reverse=True -) - -print("Slowest tests:") -for result in sorted_results[:5]: - print(f" {result.test_case.description}: {result.execution_time:.2f}s") -``` - -## Integration: TableTestRunner + Auto-Mock System - -The TableTestRunner seamlessly integrates with the auto-mock system for comprehensive workflow testing: - -```python -from test_table_runner import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Configure mocks -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .with_tool_response({"result": "mocked"}) - .with_delays(True) # Simulate realistic delays - .build()) - -# Create test case with mocking -test_case = WorkflowTestCase( - fixture_path="complex_workflow", - inputs={"query": "test"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - description="Test with mocked services", -) - -# Run test -runner = TableTestRunner() -result = runner.run_test_case(test_case) -``` - -## Auto-Mock System - -The auto-mock system provides a powerful framework for testing workflows that contain nodes requiring third-party services (LLM, APIs, tools, etc.) without making actual external calls. This enables: - -- **Fast test execution** - No network latency or API rate limits -- **Deterministic results** - Consistent outputs for reliable testing -- **Cost savings** - No API usage charges during testing -- **Offline testing** - Tests can run without internet connectivity -- **Error simulation** - Test error handling without triggering real failures - -## Architecture - -The auto-mock system consists of three main components: - -### 1. MockNodeFactory (`test_mock_factory.py`) - -- Extends `DifyNodeFactory` to intercept node creation -- Automatically detects nodes requiring third-party services -- Returns mock node implementations instead of real ones -- Supports registration of custom mock implementations - -### 2. Mock Node Implementations (`test_mock_nodes.py`) - -- `MockLLMNode` - Mocks LLM API calls (OpenAI, Anthropic, etc.) 
-- `MockAgentNode` - Mocks agent execution -- `MockToolNode` - Mocks tool invocations -- `MockKnowledgeRetrievalNode` - Mocks knowledge base queries -- `MockHttpRequestNode` - Mocks HTTP requests -- `MockParameterExtractorNode` - Mocks parameter extraction -- `MockDocumentExtractorNode` - Mocks document processing -- `MockQuestionClassifierNode` - Mocks question classification - -### 3. Mock Configuration (`test_mock_config.py`) - -- `MockConfig` - Global configuration for mock behavior -- `NodeMockConfig` - Node-specific mock configuration -- `MockConfigBuilder` - Fluent interface for building configurations - -## Usage - -### Basic Example - -```python -from test_graph_engine import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Create test runner -runner = TableTestRunner() - -# Configure mock responses -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .build()) - -# Define test case -test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, -) - -# Run test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Custom Node Outputs - -```python -# Configure specific outputs for individual nodes -mock_config = MockConfig() -mock_config.set_node_outputs("llm_node_123", { - "text": "Custom response for this specific node", - "usage": {"total_tokens": 50}, - "finish_reason": "stop", -}) -``` - -### Error Simulation - -```python -# Simulate node failures for error handling tests -mock_config = MockConfig() -mock_config.set_node_error("http_node", "Connection timeout") -``` - -### Simulated Delays - -```python -# Add realistic execution delays -from test_mock_config import NodeMockConfig - -node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response"}, - delay=1.5, # 1.5 second delay -) -mock_config.set_node_config("llm_node", node_config) -``` - -### Custom Handlers - -```python -# Define custom logic for mock outputs -def custom_handler(node): - # Access node state and return dynamic outputs - return { - "text": f"Processed: {node.graph_runtime_state.variable_pool.get('query')}", - } - -node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_handler, -) -``` - -## Node Types Automatically Mocked - -The following node types are automatically mocked when `use_auto_mock=True`: - -- `LLM` - Language model nodes -- `AGENT` - Agent execution nodes -- `TOOL` - Tool invocation nodes -- `KNOWLEDGE_RETRIEVAL` - Knowledge base query nodes -- `HTTP_REQUEST` - HTTP request nodes -- `PARAMETER_EXTRACTOR` - Parameter extraction nodes -- `DOCUMENT_EXTRACTOR` - Document processing nodes -- `QUESTION_CLASSIFIER` - Question classification nodes - -## Advanced Features - -### Registering Custom Mock Implementations - -```python -from test_mock_factory import MockNodeFactory - -# Create custom mock implementation -class CustomMockNode(BaseNode): - def _run(self): - # Custom mock logic - pass - -# Register for a specific node type -factory = MockNodeFactory(...) 
-factory.register_mock_node_type(NodeType.CUSTOM, CustomMockNode) -``` - -### Default Configurations by Node Type - -```python -# Set defaults for all nodes of a specific type -mock_config.set_default_config(NodeType.LLM, { - "temperature": 0.7, - "max_tokens": 100, -}) -``` - -### MockConfigBuilder Fluent API - -```python -config = (MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"result": "data"}) - .with_retrieval_response("Retrieved content") - .with_http_response({"status_code": 200, "body": "{}"}) - .with_node_output("node_id", {"output": "value"}) - .with_node_error("error_node", "Error message") - .with_delays(True) - .build()) -``` - -## Testing Workflows - -### 1. Create Workflow Fixture - -Create a YAML fixture file in `api/tests/fixtures/workflow/` directory defining your workflow graph. - -### 2. Configure Mocks - -Set up mock configurations for nodes that need third-party services. - -### 3. Define Test Cases - -Create `WorkflowTestCase` instances with inputs, expected outputs, and mock config. - -### 4. Run Tests - -Use `TableTestRunner` to execute test cases and validate results. - -## Best Practices - -1. **Use descriptive mock responses** - Make it clear in outputs that they are mocked -1. **Test both success and failure paths** - Use error simulation to test error handling -1. **Keep mock configs close to tests** - Define mocks in the same test file for clarity -1. **Use custom handlers sparingly** - Only when dynamic behavior is needed -1. **Document mock behavior** - Comment why specific mock values are chosen -1. **Validate mock accuracy** - Ensure mocks reflect real service behavior - -## Examples - -See `test_mock_example.py` for comprehensive examples including: - -- Basic LLM workflow testing -- Custom node outputs -- HTTP and tool workflow testing -- Error simulation -- Performance testing with delays - -## Running Tests - -### TableTestRunner Tests +Examples: ```bash -# Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py - -# Run with specific test patterns -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -k "test_echo" - -# Run with verbose output -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -v +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py ``` - -### Mock System Tests - -```bash -# Run auto-mock system tests -uv run pytest api/tests/unit_tests/graphon/graph_engine/test_auto_mock_system.py - -# Run examples -uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_example.py - -# Run simple validation -uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_simple.py -``` - -### All Tests - -```bash -# Run all graph engine tests -uv run pytest api/tests/unit_tests/graphon/graph_engine/ - -# Run with coverage -uv run pytest api/tests/unit_tests/graphon/graph_engine/ --cov=graphon.graph_engine - -# Run in parallel -uv run pytest api/tests/unit_tests/graphon/graph_engine/ -n auto -``` - -## Troubleshooting - -### Issue: Mock not being 
applied - -- Ensure `use_auto_mock=True` in `WorkflowTestCase` -- Verify node ID matches in mock config -- Check that node type is in the auto-mock list - -### Issue: Unexpected outputs - -- Debug by printing `result.actual_outputs` -- Check if custom handler is overriding expected outputs -- Verify mock config is properly built - -### Issue: Import errors - -- Ensure all mock modules are in the correct path -- Check that required dependencies are installed - -## Future Enhancements - -Potential improvements to the auto-mock system: - -1. **Recording and playback** - Record real API responses for replay in tests -1. **Mock templates** - Pre-defined mock configurations for common scenarios -1. **Async support** - Better support for async node execution -1. **Mock validation** - Validate mock outputs against node schemas -1. **Performance profiling** - Built-in performance metrics for mocked workflows diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py deleted file mode 100644 index 795362b158..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Tests for Redis command channel implementation.""" - -import json -from unittest.mock import MagicMock - -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - GraphEngineCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from graphon.variables import IntegerVariable, StringVariable - - -class TestRedisChannel: - """Test suite for RedisChannel functionality.""" - - def test_init(self): - """Test RedisChannel initialization.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - ttl = 7200 - - channel = RedisChannel(mock_redis, channel_key, ttl) - - assert channel._redis == mock_redis - assert channel._key == channel_key - assert channel._command_ttl == ttl - - def test_init_default_ttl(self): - """Test RedisChannel initialization with default TTL.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - - channel = RedisChannel(mock_redis, channel_key) - - assert channel._command_ttl == 3600 # Default TTL - - def test_send_command(self): - """Test sending a command to Redis.""" - mock_redis = MagicMock() - mock_pipe = MagicMock() - context = MagicMock() - context.__enter__.return_value = mock_pipe - context.__exit__.return_value = None - mock_redis.pipeline.return_value = context - - channel = RedisChannel(mock_redis, "test:key", 3600) - - pending_key = "test:key:pending" - - # Create a test command - command = GraphEngineCommand(command_type=CommandType.ABORT) - - # Send the command - channel.send_command(command) - - # Verify pipeline was used - mock_redis.pipeline.assert_called_once() - - # Verify rpush was called with correct data - expected_json = json.dumps(command.model_dump()) - mock_pipe.rpush.assert_called_once_with("test:key", expected_json) - - # Verify expire was set - mock_pipe.expire.assert_called_once_with("test:key", 3600) - mock_pipe.set.assert_called_once_with(pending_key, "1", ex=3600) - - # Verify execute was called - mock_pipe.execute.assert_called_once() - - def test_fetch_commands_empty(self): - """Test fetching commands when Redis list is empty.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - 
fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context] - - # No pending marker - pending_pipe.execute.return_value = [None, 0] - mock_redis.llen.return_value = 0 - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.pipeline.assert_called_once() - fetch_pipe.lrange.assert_not_called() - fetch_pipe.delete.assert_not_called() - - def test_fetch_commands_with_abort_command(self): - """Test fetching abort commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create abort command data - abort_command = AbortCommand() - command_json = json.dumps(abort_command.model_dump()) - - # Simulate Redis returning one command - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - - def test_fetch_commands_multiple(self): - """Test fetching multiple commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create multiple commands - command1 = GraphEngineCommand(command_type=CommandType.ABORT) - command2 = AbortCommand() - - command1_json = json.dumps(command1.model_dump()) - command2_json = json.dumps(command2.model_dump()) - - # Simulate Redis returning multiple commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 2 - assert commands[0].command_type == CommandType.ABORT - assert isinstance(commands[1], AbortCommand) - - def test_fetch_commands_with_update_variables_command(self): - """Test fetching update variables command from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="baz", value=123, 
selector=["node2", "baz"]), - ), - ] - ) - command_json = json.dumps(update_command.model_dump()) - - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], UpdateVariablesCommand) - assert isinstance(commands[0].updates[0].value, StringVariable) - assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] - assert commands[0].updates[0].value.value == "bar" - - def test_fetch_commands_skips_invalid_json(self): - """Test that invalid JSON commands are skipped.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mix valid and invalid JSON - valid_command = AbortCommand() - valid_json = json.dumps(valid_command.model_dump()) - invalid_json = b"invalid json {" - - # Simulate Redis returning mixed valid/invalid commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - # Should only return the valid command - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - - def test_deserialize_command_abort(self): - """Test deserializing an abort command.""" - channel = RedisChannel(MagicMock(), "test:key") - - abort_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(abort_data) - - assert isinstance(command, AbortCommand) - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_generic(self): - """Test deserializing a generic command.""" - channel = RedisChannel(MagicMock(), "test:key") - - # For now, only ABORT is supported, but test generic handling - generic_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(generic_data) - - assert command is not None - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_invalid(self): - """Test deserializing invalid command data.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Missing command_type - invalid_data = {"some_field": "value"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_deserialize_command_invalid_type(self): - """Test deserializing command with invalid type.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Invalid command type - invalid_data = {"command_type": "INVALID_TYPE"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_atomic_fetch_and_clear(self): - """Test that fetch_commands atomically fetches and clears the list.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - command = 
AbortCommand() - command_json = json.dumps(command.model_dump()) - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - - # First fetch should return the command - commands = channel.fetch_commands() - assert len(commands) == 1 - - # Verify both lrange and delete were called in the pipeline - assert fetch_pipe.lrange.call_count == 1 - assert fetch_pipe.delete.call_count == 1 - fetch_pipe.lrange.assert_called_with("test:key", 0, -1) - fetch_pipe.delete.assert_called_with("test:key") - - def test_fetch_commands_without_pending_marker_returns_empty(self): - """Ensure we avoid unnecessary list reads when pending flag is missing.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Pending flag absent - pending_pipe.execute.return_value = [None, 0] - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.llen.assert_not_called() - assert mock_redis.pipeline.call_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py deleted file mode 100644 index cacbe9ba4e..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Tests for graph engine event handlers.""" - -from __future__ import annotations - -from graphon.entities.base_node_data import RetryConfig -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.graph_engine.domain.graph_execution import GraphExecution -from graphon.graph_engine.event_management.event_handlers import EventHandler -from graphon.graph_engine.event_management.event_manager import EventManager -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from graphon.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from graphon.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from graphon.node_events import NodeRunResult -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now - - -class _StubEdgeProcessor: - """Minimal edge processor stub for tests.""" - - -class _StubErrorHandler: - """Minimal error handler stub for tests.""" - - -class _StubNode: - """Simple node stub exposing the attributes needed by the state manager.""" - - def __init__(self, node_id: str) -> None: - self.id = node_id - self.state = NodeState.UNKNOWN - self.title = "Stub Node" - self.execution_type = NodeExecutionType.EXECUTABLE - self.error_strategy = None - self.retry_config = RetryConfig() - self.retry = False - - -def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]: - """Construct an EventHandler with in-memory dependencies for testing.""" - - node = _StubNode(node_id) - graph = Graph(nodes={node_id: node}, edges={}, 
in_edges={}, out_edges={}, root_node=node) - - variable_pool = VariablePool() - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_execution = GraphExecution(workflow_id="test-workflow") - - event_manager = EventManager() - state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue()) - response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph) - - handler = EventHandler( - graph=graph, - graph_runtime_state=runtime_state, - graph_execution=graph_execution, - response_coordinator=response_coordinator, - event_collector=event_manager, - edge_processor=_StubEdgeProcessor(), - state_manager=state_manager, - error_handler=_StubErrorHandler(), - ) - - return handler, event_manager, graph_execution - - -def test_retry_does_not_emit_additional_start_event() -> None: - """Ensure retry attempts do not produce duplicate start events.""" - - node_id = "test-node" - handler, event_manager, graph_execution = _build_event_handler(node_id) - - execution_id = "exec-1" - node_type = BuiltinNodeTypes.CODE - start_time = naive_utc_now() - - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(start_event) - - retry_event = NodeRunRetryEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - error="boom", - retry_index=1, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error="boom", - error_type="TestError", - ), - ) - handler.dispatch(retry_event) - - # Simulate the node starting execution again after retry - second_start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(second_start_event) - - collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined] - - assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent] - - node_execution = graph_execution.get_or_create_node_execution(node_id) - assert node_execution.retry_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py deleted file mode 100644 index dc0998caf1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Tests for the EventManager.""" - -from __future__ import annotations - -import logging - -from graphon.graph_engine.event_management.event_manager import EventManager -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent - - -class _FaultyLayer(GraphEngineLayer): - """Layer that raises from on_event to test error handling.""" - - def on_graph_start(self) -> None: # pragma: no cover - not used in tests - pass - - def on_event(self, event: GraphEngineEvent) -> None: - raise RuntimeError("boom") - - def on_graph_end(self, error: Exception | None) -> None: # pragma: no cover - not used in tests - pass - - -def test_event_manager_logs_layer_errors(caplog) -> None: - """Ensure errors raised by layers are logged when collecting events.""" - - event_manager = EventManager() - event_manager.set_layers([_FaultyLayer()]) - - with caplog.at_level(logging.ERROR): - event_manager.collect(GraphEngineEvent()) - - error_logs = [record 
for record in caplog.records if "Error in layer on_event" in record.getMessage()] - assert error_logs, "Expected layer errors to be logged" - - log_record = error_logs[0] - assert log_record.exc_info is not None - assert isinstance(log_record.exc_info[1], RuntimeError) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index cf8811dc2b..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for graph traversal components.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py deleted file mode 100644 index b030496eb1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py +++ /dev/null @@ -1,307 +0,0 @@ -"""Unit tests for skip propagator.""" - -from unittest.mock import MagicMock, create_autospec - -from graphon.graph import Edge, Graph -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.graph_traversal.skip_propagator import SkipPropagator - - -class TestSkipPropagator: - """Test suite for SkipPropagator.""" - - def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None: - """When there are unknown incoming edges, propagation should stop.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - # Setup graph edges dict - mock_graph.edges = {"edge_1": mock_edge} - - # Setup incoming edges - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_unknown=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_graph.get_incoming_edges.assert_called_once_with("node_2") - mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges) - # Should not call any other state manager methods - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None: - """When there is at least one taken edge, node should be enqueued.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_taken=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - 
mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None: - """When all incoming edges are skipped, should propagate skip to node.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - - def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None: - """_propagate_skip_to_node should mark node and all outgoing edges as skipped.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create outgoing edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_2" - edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_3" - edge2.head = "node_downstream_2" - - # Setup graph edges dict for propagate_skip_from_edge - mock_graph.edges = {"edge_2": edge1, "edge_3": edge2} - mock_graph.get_outgoing_edges.return_value = [edge1, edge2] - - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Use mock to call private method - # Act - propagator._propagate_skip_to_node("node_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # Should recursively propagate from each edge - # Since propagate_skip_from_edge is called, we need to verify it was called - # But we can't directly verify due to recursion. We'll trust the logic. 
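The deleted suite above pins the propagator's three-way rule for an edge's head node: any UNKNOWN incoming edge stops processing, any TAKEN edge enqueues the node, and only when every incoming edge is skipped does the skip cascade onward. A minimal sketch of that rule, using the flag dictionary and state-manager calls the mocks above assert on (parameter shapes are illustrative; the real `SkipPropagator` API may differ):

```python
def propagate_skip_from_edge(graph, state_manager, edge_id: str) -> None:
    """Sketch of the decision rule exercised by the tests above."""
    node_id = graph.edges[edge_id].head
    flags = state_manager.analyze_edge_states(graph.get_incoming_edges(node_id))
    if flags["has_unknown"]:
        return  # some predecessor is still undecided: wait for it
    if flags["has_taken"]:
        state_manager.enqueue_node(node_id)  # at least one live path: run the node
        state_manager.start_execution(node_id)
    elif flags["all_skipped"]:
        state_manager.mark_node_skipped(node_id)  # no live path: skip and cascade
        for edge in graph.get_outgoing_edges(node_id):
            state_manager.mark_edge_skipped(edge.id)
            propagate_skip_from_edge(graph, state_manager, edge.id)
```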
- - def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None: - """skip_branch_paths should mark all unselected edges as skipped and propagate.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create unselected edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_downstream_1" - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_2" - edge2.head = "node_downstream_2" - - unselected_edges = [edge1, edge2] - - # Setup graph edges dict - mock_graph.edges = {"edge_1": edge1, "edge_2": edge2} - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.skip_branch_paths(unselected_edges) - - # Assert - mock_state_manager.mark_edge_skipped.assert_any_call("edge_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # propagate_skip_from_edge should be called for each edge - # We can't directly verify due to the mock, but the logic is covered - - def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None: - """Skip propagation should recursively propagate through the graph.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4 - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_2" - - edge3 = MagicMock(spec=Edge) - edge3.id = "edge_3" - edge3.head = "node_4" - - mock_graph.edges = {"edge_1": edge1, "edge_3": edge3} - - # Setup get_incoming_edges to return different values based on node - def get_incoming_edges_side_effect(node_id): - if node_id == "node_2": - return [edge1] - elif node_id == "node_4": - return [edge3] - return [] - - mock_graph.get_incoming_edges.side_effect = get_incoming_edges_side_effect - - # Setup get_outgoing_edges to return different values based on node - def get_outgoing_edges_side_effect(node_id): - if node_id == "node_2": - return [edge3] - elif node_id == "node_4": - return [] # No outgoing edges, stops recursion - return [] - - mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect - - # Setup state manager to return all_skipped for both nodes - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - # Should mark node_2 as skipped - mock_state_manager.mark_node_skipped.assert_any_call("node_2") - # Should mark edge_3 as skipped - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - # Should propagate to node_4 - mock_state_manager.mark_node_skipped.assert_any_call("node_4") - assert mock_state_manager.mark_node_skipped.call_count == 2 - - def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None: - """Test with mixed edge states (some unknown, some taken, some skipped).""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)] - 
mock_graph.get_incoming_edges.return_value = incoming_edges - - # Test 1: has_unknown=True, has_taken=False, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should stop processing - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 2: has_unknown=False, has_taken=True, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should enqueue node - mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 3: has_unknown=False, has_taken=False, all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should propagate skip - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py deleted file mode 100644 index 2fead1d719..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Utilities for testing HumanInputNode without database dependencies.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRecipientEntity, - HumanInputFormRepository, -) -from graphon.nodes.human_input.enums import HumanInputFormStatus -from libs.datetime_utils import naive_utc_now - - -class _InMemoryFormRecipient(HumanInputFormRecipientEntity): - """Minimal recipient entity required by the repository interface.""" - - def __init__(self, recipient_id: str, token: str) -> None: - self._id = recipient_id - self._token = token - - @property - def id(self) -> str: - return self._id - - @property - def token(self) -> str: - return self._token - - -@dataclass -class _InMemoryFormEntity(HumanInputFormEntity): - form_id: str - rendered: str - token: str | None = None - action_id: str | None = None - data: Mapping[str, Any] | None = None - is_submitted: bool = False - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() - - @property - def id(self) -> str: - return self.form_id - - @property - def submission_token(self) -> str | None: - return self.token - - @property - def recipients(self) -> list[HumanInputFormRecipientEntity]: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - 
@property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class InMemoryHumanInputFormRepository(HumanInputFormRepository): - """Pure in-memory repository used by workflow graph engine tests.""" - - def __init__(self) -> None: - self._form_counter = 0 - self.created_params: list[FormCreateParams] = [] - self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - self.created_params.append(params) - self._form_counter += 1 - form_id = f"form-{self._form_counter}" - token = f"token-{form_id}" - entity = _InMemoryFormEntity( - form_id=form_id, - rendered=params.rendered_content, - token=token, - ) - self.created_forms.append(entity) - self._forms_by_node_id[params.node_id] = entity - return entity - - def get_form(self, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - # Convenience helpers for tests ------------------------------------- - - def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: - """Simulate a human submission for the next repository lookup.""" - - if not self.created_forms: - raise AssertionError("no form has been created to attach submission data") - entity = self.created_forms[-1] - entity.action_id = action_id - entity.data = form_data or {} - entity.is_submitted = True - entity.status_value = HumanInputFormStatus.SUBMITTED - entity.expiration = naive_utc_now() + timedelta(days=1) - - def clear_submission(self) -> None: - if not self.created_forms: - return - for form in self.created_forms: - form.action_id = None - form.data = None - form.is_submitted = False - form.status_value = HumanInputFormStatus.WAITING diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index b642dc82fe..41627f5e0b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,13 +5,12 @@ Shared fixtures for ObservabilityLayer tests. 
from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from graphon.enums import BuiltinNodeTypes - @pytest.fixture def memory_span_exporter(): @@ -62,9 +61,10 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" - from core.tools.entities.tool_entities import ToolProviderType from graphon.nodes.tool.entities import ToolNodeData + from core.tools.entities.tool_entities import ToolProviderType + node = MagicMock() node.id = "test-tool-node-id" node.title = "Test Tool Node" @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from graphon.graph_events.node import NodeRunSucceededEvent - from graphon.node_events.base import NodeRunResult + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py deleted file mode 100644 index 7ff77c19c1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -import pytest - -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers.base import ( - GraphEngineLayer, - GraphEngineLayerNotInitializedError, -) -from graphon.graph_events import GraphEngineEvent - -from ..test_table_runner import WorkflowRunner - - -class LayerForTest(GraphEngineLayer): - def on_graph_start(self) -> None: - pass - - def on_event(self, event: GraphEngineEvent) -> None: - pass - - def on_graph_end(self, error: Exception | None) -> None: - pass - - -def test_layer_runtime_state_raises_when_uninitialized() -> None: - layer = LayerForTest() - - with pytest.raises(GraphEngineLayerNotInitializedError): - _ = layer.graph_runtime_state - - -def test_layer_runtime_state_available_after_engine_layer() -> None: - runner = WorkflowRunner() - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture( - fixture_data, - inputs={"query": "test layer state"}, - ) - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - layer = LayerForTest() - engine.layer(layer) - - outputs = layer.graph_runtime_state.outputs - ready_queue_size = layer.graph_runtime_state.ready_queue_size - - assert outputs == {} - assert isinstance(ready_queue_size, int) - assert ready_queue_size >= 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 80874e768a..99d131737e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -3,15 +3,16 @@ from datetime 
 import datetime
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
+from graphon.graph_engine.entities.commands import CommandType
+from graphon.graph_events import NodeRunSucceededEvent
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.node_events import NodeRunResult
+
 from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
 from core.app.workflow.layers.llm_quota import LLMQuotaLayer
 from core.errors.error import QuotaExceededError
 from core.model_manager import ModelInstance
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
-from graphon.graph_engine.entities.commands import CommandType
-from graphon.graph_events.node import NodeRunSucceededEvent
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.node_events import NodeRunResult
 
 
 def _build_dify_context() -> DifyRunContext:
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
index 14ce55938d..9cf72763ee 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
@@ -13,10 +13,10 @@ Test coverage:
 from unittest.mock import patch
 
 import pytest
+from graphon.enums import BuiltinNodeTypes
 from opentelemetry.trace import StatusCode
 
 from core.app.workflow.layers.observability import ObservabilityLayer
-from graphon.enums import BuiltinNodeTypes
 
 
 class TestObservabilityLayerInitialization:
@@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration:
         self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event
     ):
         """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes."""
-        from graphon.node_events.base import NodeRunResult
+        from graphon.node_events import NodeRunResult
 
         mock_result_event.node_run_result = NodeRunResult(
             inputs={},
@@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration:
         self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event
     ):
         """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes."""
-        from graphon.node_events.base import NodeRunResult
+        from graphon.node_events import NodeRunResult
 
         mock_result_event.node_run_result = NodeRunResult(
             inputs={"query": "test query"},
@@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration:
         self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event
     ):
         """Test that result_event parameter allows parsers to extract inputs and outputs."""
-        from graphon.node_events.base import NodeRunResult
+        from graphon.node_events import NodeRunResult
 
         mock_result_event.node_run_result = NodeRunResult(
             inputs={"input_key": "input_value"},
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py
deleted file mode 100644
index ab3a31f673..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py
+++ /dev/null
@@ -1,189 +0,0 @@
-"""Tests for dispatcher command checking behavior."""
-
-from __future__ import annotations
-
-import queue
-from unittest import mock
-
-from graphon.entities.pause_reason import SchedulingPause
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
-from graphon.graph_engine.event_management.event_handlers import EventHandler
-from graphon.graph_engine.orchestration.dispatcher import Dispatcher
-from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator
-from graphon.graph_events import (
-    GraphNodeEventBase,
-    NodeRunPauseRequestedEvent,
-    NodeRunStartedEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.node_events import NodeRunResult
-from libs.datetime_utils import naive_utc_now
-
-
-def test_dispatcher_should_consume_remains_events_after_pause():
-    event_queue = queue.Queue()
-    event_queue.put(
-        GraphNodeEventBase(
-            id="test",
-            node_id="test",
-            node_type=BuiltinNodeTypes.START,
-        )
-    )
-    event_handler = mock.Mock(spec=EventHandler)
-    execution_coordinator = mock.Mock(spec=ExecutionCoordinator)
-    execution_coordinator.paused.return_value = True
-    dispatcher = Dispatcher(
-        event_queue=event_queue,
-        event_handler=event_handler,
-        execution_coordinator=execution_coordinator,
-    )
-    dispatcher._dispatcher_loop()
-    assert event_queue.empty()
-
-
-class _StubExecutionCoordinator:
-    """Stub execution coordinator that tracks command checks."""
-
-    def __init__(self) -> None:
-        self.command_checks = 0
-        self.scaling_checks = 0
-        self.execution_complete = False
-        self.failed = False
-        self._paused = False
-
-    def process_commands(self) -> None:
-        self.command_checks += 1
-
-    def check_scaling(self) -> None:
-        self.scaling_checks += 1
-
-    @property
-    def paused(self) -> bool:
-        return self._paused
-
-    @property
-    def aborted(self) -> bool:
-        return False
-
-    def mark_complete(self) -> None:
-        self.execution_complete = True
-
-    def mark_failed(self, error: Exception) -> None:  # pragma: no cover - defensive, not triggered in tests
-        self.failed = True
-
-
-class _StubEventHandler:
-    """Minimal event handler that marks execution complete after handling an event."""
-
-    def __init__(self, coordinator: _StubExecutionCoordinator) -> None:
-        self._coordinator = coordinator
-        self.events = []
-
-    def dispatch(self, event) -> None:
-        self.events.append(event)
-        self._coordinator.mark_complete()
-
-
-def _run_dispatcher_for_event(event) -> int:
-    """Run the dispatcher loop for a single event and return command check count."""
-    event_queue: queue.Queue = queue.Queue()
-    event_queue.put(event)
-
-    coordinator = _StubExecutionCoordinator()
-    event_handler = _StubEventHandler(coordinator)
-
-    dispatcher = Dispatcher(
-        event_queue=event_queue,
-        event_handler=event_handler,
-        execution_coordinator=coordinator,
-    )
-
-    dispatcher._dispatcher_loop()
-
-    return coordinator.command_checks
-
-
-def _make_started_event() -> NodeRunStartedEvent:
-    return NodeRunStartedEvent(
-        id="start-event",
-        node_id="node-1",
-        node_type=BuiltinNodeTypes.CODE,
-        node_title="Test Node",
-        start_at=naive_utc_now(),
-    )
-
-
-def _make_succeeded_event() -> NodeRunSucceededEvent:
-    return NodeRunSucceededEvent(
-        id="success-event",
-        node_id="node-1",
-        node_type=BuiltinNodeTypes.CODE,
-        node_title="Test Node",
-        start_at=naive_utc_now(),
-        node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
-    )
-
-
-def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None:
-    """Dispatcher polls commands when idle and after completion events."""
-    started_checks = _run_dispatcher_for_event(_make_started_event())
-    succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())
-
-    assert started_checks == 2
-    assert succeeded_checks == 3
-
-
-class _PauseStubEventHandler:
-    """Minimal event handler that marks execution complete after handling an event."""
-
-    def __init__(self, coordinator: _StubExecutionCoordinator) -> None:
-        self._coordinator = coordinator
-        self.events = []
-
-    def dispatch(self, event) -> None:
-        self.events.append(event)
-        if isinstance(event, NodeRunPauseRequestedEvent):
-            self._coordinator.mark_complete()
-
-
-def test_dispatcher_drain_event_queue():
-    events = [
-        NodeRunStartedEvent(
-            id="start-event",
-            node_id="node-1",
-            node_type=BuiltinNodeTypes.CODE,
-            node_title="Code",
-            start_at=naive_utc_now(),
-        ),
-        NodeRunPauseRequestedEvent(
-            id="pause-event",
-            node_id="node-1",
-            node_type=BuiltinNodeTypes.CODE,
-            reason=SchedulingPause(message="test pause"),
-        ),
-        NodeRunSucceededEvent(
-            id="success-event",
-            node_id="node-1",
-            node_type=BuiltinNodeTypes.CODE,
-            start_at=naive_utc_now(),
-            node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
-        ),
-    ]
-
-    event_queue: queue.Queue = queue.Queue()
-    for e in events:
-        event_queue.put(e)
-
-    coordinator = _StubExecutionCoordinator()
-    event_handler = _PauseStubEventHandler(coordinator)
-
-    dispatcher = Dispatcher(
-        event_queue=event_queue,
-        event_handler=event_handler,
-        execution_coordinator=coordinator,
-    )
-
-    dispatcher._dispatcher_loop()
-
-    # ensure all events are drained.
-    assert event_queue.empty()
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py
deleted file mode 100644
index 1510c8e595..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from graphon.graph_events import (
-    GraphRunStartedEvent,
-    GraphRunSucceededEvent,
-    NodeRunStartedEvent,
-    NodeRunStreamChunkEvent,
-    NodeRunSucceededEvent,
-)
-
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-
-def test_answer_end_with_text():
-    fixture_name = "answer_end_with_text"
-    case = WorkflowTestCase(
-        fixture_name,
-        query="Hello, AI!",
-        expected_outputs={"answer": "prefixHello, AI!suffix"},
-        expected_event_sequence=[
-            GraphRunStartedEvent,
-            # Start
-            NodeRunStartedEvent,
-            # The chunks are now emitted as the Answer node processes them
-            # since sys.query is a special selector that gets attributed to
-            # the active response node
-            NodeRunStreamChunkEvent,  # prefix
-            NodeRunStreamChunkEvent,  # sys.query
-            NodeRunStreamChunkEvent,  # suffix
-            NodeRunSucceededEvent,
-            # Answer
-            NodeRunStartedEvent,
-            NodeRunSucceededEvent,
-            GraphRunSucceededEvent,
-        ],
-    )
-    runner = TableTestRunner()
-    result = runner.run_test_case(case)
-    assert result.success, f"Test failed: {result.error}"
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py
deleted file mode 100644
index 6569439b56..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from .test_mock_config import MockConfigBuilder
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-LLM_NODE_ID = "1759052580454"
-
-
-def test_answer_nodes_emit_in_order() -> None:
-    mock_config = (
-        MockConfigBuilder()
-        .with_llm_response("unused default")
-        .with_node_output(LLM_NODE_ID, {"text": "mocked llm text"})
-        .build()
-    )
-
-    expected_answer = "--- answer 1 ---\n\nfoo\n--- answer 2 ---\n\nmocked llm text\n"
-
-    case = WorkflowTestCase(
-        fixture_path="test-answer-order",
-        query="",
-        expected_outputs={"answer": expected_answer},
-        use_auto_mock=True,
-        mock_config=mock_config,
-    )
-
-    runner = TableTestRunner()
-    result = runner.run_test_case(case)
-
-    assert result.success, result.error
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py
deleted file mode 100644
index 05ec565def..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-
-def test_array_iteration_formatting_workflow():
-    """
-    Validate Iteration node processes [1,2,3] into formatted strings.
-
-    Fixture description expects:
-    {"output": ["output: 1", "output: 2", "output: 3"]}
-    """
-    runner = TableTestRunner()
-
-    test_case = WorkflowTestCase(
-        fixture_path="array_iteration_formatting_workflow",
-        inputs={},
-        expected_outputs={"output": ["output: 1", "output: 2", "output: 3"]},
-        description="Iteration formats numbers into strings",
-        use_auto_mock=True,
-    )
-
-    result = runner.run_test_case(test_case)
-
-    assert result.success, f"Iteration workflow failed: {result.error}"
-    assert result.actual_outputs == test_case.expected_outputs
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
deleted file mode 100644
index 5d0b37acc5..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
+++ /dev/null
@@ -1,392 +0,0 @@
-"""
-Tests for the auto-mock system.
-
-This module contains tests that validate the auto-mock functionality
-for workflows containing nodes that require third-party services.
-"""
-
-import pytest
-
-from graphon.enums import BuiltinNodeTypes
-from tests.workflow_test_utils import build_test_graph_init_params
-
-from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-
-def test_simple_llm_workflow_with_auto_mock():
-    """Test that a simple LLM workflow runs successfully with auto-mocking."""
-    runner = TableTestRunner()
-
-    # Create mock configuration
-    mock_config = MockConfigBuilder().with_llm_response("This is a test response from mocked LLM").build()
-
-    test_case = WorkflowTestCase(
-        fixture_path="basic_llm_chat_workflow",
-        inputs={"query": "Hello, how are you?"},
-        expected_outputs={"answer": "This is a test response from mocked LLM"},
-        description="Simple LLM workflow with auto-mock",
-        use_auto_mock=True,
-        mock_config=mock_config,
-    )
-
-    result = runner.run_test_case(test_case)
-
-    assert result.success, f"Workflow failed: {result.error}"
-    assert result.actual_outputs is not None
-    assert "answer" in result.actual_outputs
-    assert result.actual_outputs["answer"] == "This is a test response from mocked LLM"
-
-
-def test_llm_workflow_with_custom_node_output():
-    """Test LLM workflow with custom output for specific node."""
-    runner = TableTestRunner()
-
-    # Create mock configuration with custom output for specific node
-    mock_config = MockConfig()
-    mock_config.set_node_outputs(
-        "llm_node",
-        {
-            "text": "Custom response for this specific node",
-            "usage": {
-                "prompt_tokens": 20,
-                "completion_tokens": 10,
-                "total_tokens": 30,
-            },
-            "finish_reason": "stop",
-        },
-    )
-
-    test_case = WorkflowTestCase(
-        fixture_path="basic_llm_chat_workflow",
-        inputs={"query": "Test query"},
-        expected_outputs={"answer": "Custom response for this specific node"},
-        description="LLM workflow with custom node output",
-        use_auto_mock=True,
-        mock_config=mock_config,
-    )
-
-    result = runner.run_test_case(test_case)
-
-    assert result.success, f"Workflow failed: {result.error}"
-    assert result.actual_outputs is not None
-    assert result.actual_outputs["answer"] == "Custom response for this specific node"
-
-
-def test_http_tool_workflow_with_auto_mock():
-    """Test workflow with HTTP request and tool nodes using auto-mock."""
-    runner = TableTestRunner()
-
-    # Create mock configuration
-    mock_config = MockConfig()
-    mock_config.set_node_outputs(
-        "http_node",
-        {
-            "status_code": 200,
-            "body": '{"key": "value", "number": 42}',
-            "headers": {"content-type": "application/json"},
-        },
-    )
-    mock_config.set_node_outputs(
-        "tool_node",
-        {
-            "result": {"key": "value", "number": 42},
-        },
-    )
-
-    test_case = WorkflowTestCase(
-        fixture_path="http_request_with_json_tool_workflow",
-        inputs={"url": "https://api.example.com/data"},
-        expected_outputs={
-            "status_code": 200,
-            "parsed_data": {"key": "value", "number": 42},
-        },
-        description="HTTP and Tool workflow with auto-mock",
-        use_auto_mock=True,
-        mock_config=mock_config,
-    )
-
-    result = runner.run_test_case(test_case)
-
-    assert result.success, f"Workflow failed: {result.error}"
-    assert result.actual_outputs is not None
-    assert result.actual_outputs["status_code"] == 200
-    assert result.actual_outputs["parsed_data"] == {"key": "value", "number": 42}
-
-
-def test_workflow_with_simulated_node_error():
-    """Test that workflows handle simulated node errors correctly."""
-    runner = TableTestRunner()
-
-    # Create mock configuration with error
-    mock_config = MockConfig()
-    mock_config.set_node_error("llm_node", "Simulated LLM API error")
-
-    test_case = WorkflowTestCase(
-        fixture_path="basic_llm_chat_workflow",
-        inputs={"query": "This should fail"},
-        expected_outputs={},  # We expect failure, so no outputs
-        description="LLM workflow with simulated error",
-        use_auto_mock=True,
-        mock_config=mock_config,
-    )
-
-    result = runner.run_test_case(test_case)
-
-    # The workflow should fail due to the simulated error
-    assert not result.success
-    assert result.error is not None
-
-
-def test_workflow_with_mock_delays():
-    """Test that mock delays work correctly."""
-    runner = TableTestRunner()
-
-    # Create mock configuration with delays
-    mock_config = MockConfig(simulate_delays=True)
-    node_config = NodeMockConfig(
-        node_id="llm_node",
-        outputs={"text": "Response after delay"},
-        delay=0.1,  # 100ms delay
-    )
-    mock_config.set_node_config("llm_node", node_config)
-
-    test_case = WorkflowTestCase(
-        fixture_path="basic_llm_chat_workflow",
-        inputs={"query": "Test with delay"},
-        expected_outputs={"answer": "Response after delay"},
-        description="LLM workflow with simulated delay",
-        use_auto_mock=True,
-        mock_config=mock_config,
-    )
-
-    result = runner.run_test_case(test_case)
-
-    assert result.success, f"Workflow failed: {result.error}"
-    # Execution time should be at least the delay
-    assert result.execution_time >= 0.1
-
-
-def test_mock_config_builder():
-    """Test the MockConfigBuilder fluent interface."""
-    config = (
-        MockConfigBuilder()
-        .with_llm_response("LLM response")
-        .with_agent_response("Agent response")
-        .with_tool_response({"tool": "output"})
-        .with_retrieval_response("Retrieval content")
-        .with_http_response({"status_code": 201, "body": "created"})
-        .with_node_output("node1", {"output": "value"})
-        .with_node_error("node2", "error message")
-        .with_delays(True)
-        .build()
-    )
-
-    assert config.default_llm_response == "LLM response"
-    assert config.default_agent_response == "Agent response"
-    assert config.default_tool_response == {"tool": "output"}
-    assert config.default_retrieval_response == "Retrieval content"
-    assert config.default_http_response == {"status_code": 201, "body": "created"}
-    assert config.simulate_delays is True
-
-    node1_config = config.get_node_config("node1")
-    assert node1_config is not None
-    assert node1_config.outputs == {"output": "value"}
-
-    node2_config = config.get_node_config("node2")
-    assert node2_config is not None
-    assert node2_config.error == "error message"
-
-
-def test_mock_factory_node_type_detection():
-    """Test that MockNodeFactory correctly identifies nodes to mock."""
-    from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-    from graphon.runtime import GraphRuntimeState, VariablePool
-
-    from .test_mock_factory import MockNodeFactory
-
-    graph_init_params = build_test_graph_init_params(
-        workflow_id="test",
-        graph_config={},
-        tenant_id="test",
-        app_id="test",
-        user_id="test",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.SERVICE_API,
-    )
-    graph_runtime_state = GraphRuntimeState(
-        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
-        start_at=0,
-        total_tokens=0,
-        node_run_steps=0,
-    )
-    factory = MockNodeFactory(
-        graph_init_params=graph_init_params,
-        graph_runtime_state=graph_runtime_state,
-        mock_config=None,
-    )
-
-    # Test that third-party service nodes are identified for mocking
-    assert factory.should_mock_node(BuiltinNodeTypes.LLM)
-    assert factory.should_mock_node(BuiltinNodeTypes.AGENT)
-    assert factory.should_mock_node(BuiltinNodeTypes.TOOL)
-    assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL)
-    assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST)
-    assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR)
-    assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR)
-
-    # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy)
-    assert factory.should_mock_node(BuiltinNodeTypes.CODE)
-    assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM)
-
-    # Test that non-service nodes are not mocked
-    assert not factory.should_mock_node(BuiltinNodeTypes.START)
-    assert not factory.should_mock_node(BuiltinNodeTypes.END)
-    assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE)
-    assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR)
-
-
-def test_custom_mock_handler():
-    """Test using a custom handler function for mock outputs."""
-    runner = TableTestRunner()
-
-    # Custom handler that modifies output based on input
-    def custom_llm_handler(node) -> dict:
-        # In a real scenario, we could access node.graph_runtime_state.variable_pool
-        # to get the actual inputs
-        return {
-            "text": "Custom handler response",
-            "usage": {
-                "prompt_tokens": 5,
-                "completion_tokens": 3,
-                "total_tokens": 8,
-            },
-            "finish_reason": "stop",
-        }
-
-    mock_config = MockConfig()
-    node_config = NodeMockConfig(
-        node_id="llm_node",
-        custom_handler=custom_llm_handler,
-    )
-    mock_config.set_node_config("llm_node", node_config)
-
-    test_case = WorkflowTestCase(
-        fixture_path="basic_llm_chat_workflow",
-        inputs={"query": "Test custom handler"},
-        expected_outputs={"answer": "Custom handler response"},
-        description="LLM workflow with custom handler",
-        use_auto_mock=True,
-        mock_config=mock_config,
-    )
-
-    result = runner.run_test_case(test_case)
-
-    assert result.success, f"Workflow failed: {result.error}"
-    assert result.actual_outputs["answer"] == "Custom handler response"
-
-
-def test_workflow_without_auto_mock():
-    """Test that workflows work normally without auto-mock enabled."""
-    runner = TableTestRunner()
-
-    # This test uses the echo workflow which doesn't need external services
-    test_case = WorkflowTestCase(
-        fixture_path="simple_passthrough_workflow",
-        inputs={"query": "Test without mock"},
-        expected_outputs={"query": "Test without mock"},
-        description="Echo workflow without auto-mock",
-        use_auto_mock=False,  # Auto-mock disabled
-    )
-
-    result = runner.run_test_case(test_case)
-
-    assert result.success, f"Workflow failed: {result.error}"
-    assert result.actual_outputs["query"] == "Test without mock"
-
-
-def test_register_custom_mock_node():
-    """Test registering a custom mock implementation for a node type."""
-    from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-    from graphon.nodes.template_transform import TemplateTransformNode
-    from graphon.runtime import GraphRuntimeState, VariablePool
-
-    from .test_mock_factory import MockNodeFactory
-
-    # Create a custom mock for TemplateTransformNode
-    class MockTemplateTransformNode(TemplateTransformNode):
-        def _run(self):
-            # Custom mock implementation
-            pass
-
-    graph_init_params = build_test_graph_init_params(
-        workflow_id="test",
-        graph_config={},
-        tenant_id="test",
-        app_id="test",
-        user_id="test",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.SERVICE_API,
-    )
-    graph_runtime_state = GraphRuntimeState(
-        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
-        start_at=0,
-        total_tokens=0,
-        node_run_steps=0,
-    )
-    factory = MockNodeFactory(
-        graph_init_params=graph_init_params,
-        graph_runtime_state=graph_runtime_state,
-        mock_config=None,
-    )
-
-    # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy)
-    assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM)
-
-    # Unregister mock
-    factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM)
-    assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM)
-
-    # Re-register custom mock
-    factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, MockTemplateTransformNode)
-    assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM)
-
-
-def test_default_config_by_node_type():
-    """Test setting default configurations by node type."""
-    mock_config = MockConfig()
-
-    # Set default config for all LLM nodes
-    mock_config.set_default_config(
-        BuiltinNodeTypes.LLM,
-        {
-            "default_response": "Default LLM response for all nodes",
-            "temperature": 0.7,
-        },
-    )
-
-    # Set default config for all HTTP nodes
-    mock_config.set_default_config(
-        BuiltinNodeTypes.HTTP_REQUEST,
-        {
-            "default_status": 200,
-            "default_timeout": 30,
-        },
-    )
-
-    llm_config = mock_config.get_default_config(BuiltinNodeTypes.LLM)
-    assert llm_config["default_response"] == "Default LLM response for all nodes"
-    assert llm_config["temperature"] == 0.7
-
-    http_config = mock_config.get_default_config(BuiltinNodeTypes.HTTP_REQUEST)
-    assert http_config["default_status"] == 200
-    assert http_config["default_timeout"] == 30
-
-    # Non-configured node type should return empty dict
-    tool_config = mock_config.get_default_config(BuiltinNodeTypes.TOOL)
-    assert tool_config == {}
-
-
-if __name__ == "__main__":
-    # Run all tests
-    pytest.main([__file__, "-v"])
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py
deleted file mode 100644
index cefe3b8ac8..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from graphon.graph_events import (
-    GraphRunStartedEvent,
-    GraphRunSucceededEvent,
-    NodeRunStartedEvent,
-    NodeRunStreamChunkEvent,
-    NodeRunSucceededEvent,
-)
-
-from .test_mock_config import MockConfigBuilder
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-
-def test_basic_chatflow():
-    fixture_name = "basic_chatflow"
-    mock_config = MockConfigBuilder().with_llm_response("mocked llm response").build()
-    case = WorkflowTestCase(
-        fixture_path=fixture_name,
-        use_auto_mock=True,
-        mock_config=mock_config,
-        expected_outputs={"answer": "mocked llm response"},
-        expected_event_sequence=[
-            GraphRunStartedEvent,
-            # START
-            NodeRunStartedEvent,
-            NodeRunSucceededEvent,
-            # LLM
-            NodeRunStartedEvent,
-        ]
-        + [NodeRunStreamChunkEvent] * ("mocked llm response".count(" ") + 2)
-        + [
-            NodeRunSucceededEvent,
-            # ANSWER
-            NodeRunStartedEvent,
-            NodeRunSucceededEvent,
-            GraphRunSucceededEvent,
-        ],
-    )
-
-    runner = TableTestRunner()
-    result = runner.run_test_case(case)
-    assert result.success, f"Test failed: {result.error}"
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
deleted file mode 100644
index 01ac2d7a96..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py
+++ /dev/null
@@ -1,266 +0,0 @@
-"""Test the command system for GraphEngine control."""
-
-import time
-from unittest.mock import MagicMock
-
-from
core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.pause_reason import SchedulingPause -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from graphon.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import IntegerVariable, StringVariable - - -def test_abort_command(): - """Test that GraphEngine properly handles abort commands.""" - - # Create shared GraphRuntimeState - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - # Create a minimal mock graph - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - # Create mock nodes with required attributes - using shared runtime state - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - # Mock graph methods - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - # Create command channel - command_channel = InMemoryChannel() - - # Create GraphEngine with same shared runtime state - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, # Use shared instance - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - # Queue an abort request before starting. 
- engine.request_abort("Test abort") - - # Run engine and collect events - events = list(engine.run()) - - # Verify we get start and abort events - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - assert any(isinstance(e, GraphRunAbortedEvent) for e in events) - - # Find the abort event and check its reason - abort_events = [e for e in events if isinstance(e, GraphRunAbortedEvent)] - assert len(abort_events) == 1 - assert abort_events[0].reason is not None - assert "aborted: test abort" in abort_events[0].reason.lower() - - -def test_redis_channel_serialization(): - """Test that Redis channel properly serializes and deserializes commands.""" - import json - from unittest.mock import MagicMock - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - - from graphon.graph_engine.command_channels.redis_channel import RedisChannel - - # Create channel with a specific key - channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") - - # Test sending a command - abort_command = AbortCommand(reason="Test abort") - channel.send_command(abort_command) - - # Verify redis methods were called - mock_pipeline.rpush.assert_called_once() - mock_pipeline.expire.assert_called_once() - - # Verify the serialized data - call_args = mock_pipeline.rpush.call_args - key = call_args[0][0] - command_json = call_args[0][1] - - assert key == "workflow:123:commands" - - # Verify JSON structure - command_data = json.loads(command_json) - assert command_data["command_type"] == "abort" - assert command_data["reason"] == "Test abort" - - # Test pause command serialization - pause_command = PauseCommand(reason="User requested pause") - channel.send_command(pause_command) - - assert len(mock_pipeline.rpush.call_args_list) == 2 - second_call_args = mock_pipeline.rpush.call_args_list[1] - pause_command_json = second_call_args[0][1] - pause_command_data = json.loads(pause_command_json) - assert pause_command_data["command_type"] == CommandType.PAUSE.value - assert pause_command_data["reason"] == "User requested pause" - - -def test_pause_command(): - """Test that GraphEngine properly handles pause commands.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - pause_command = PauseCommand(reason="User requested pause") - 
command_channel.send_command(pause_command) - - events = list(engine.run()) - - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)] - assert len(pause_events) == 1 - assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")] - - graph_execution = engine.graph_runtime_state.graph_execution - assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] - - -def test_update_variables_command_updates_pool(): - """Test that GraphEngine updates variable pool via update variables command.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), - ), - ] - ) - command_channel.send_command(update_command) - - list(engine.run()) - - updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) - added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) - - assert updated_existing is not None - assert updated_existing.value == "new value" - assert added_new is not None - assert added_new.value == 123 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py deleted file mode 100644 index ba9c502452..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Test suite for complex branch workflow with parallel execution and conditional routing. - -This test suite validates the behavior of a workflow that: -1. Executes nodes in parallel (IF/ELSE and LLM branches) -2. Routes based on conditional logic (query containing 'hello') -3. 
Handles multiple answer nodes with different outputs -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -class TestComplexBranchWorkflow: - """Test suite for complex branch workflow with parallel execution.""" - - def setup_method(self): - """Set up test environment before each test method.""" - self.runner = TableTestRunner() - self.fixture_path = "test_complex_branch" - - def test_hello_branch_with_llm(self): - """ - Test when query contains 'hello' - should trigger true branch. - Both IF/ELSE and LLM should execute in parallel. - """ - mock_text_1 = "This is a mocked LLM response for hello world" - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="hello world", - expected_outputs={ - "answer": f"contains 'hello'{mock_text_1}", - }, - description="Basic hello case with parallel LLM execution", - use_auto_mock=True, - mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="say hello to everyone", - expected_outputs={ - "answer": "contains 'hello'Mocked response for greeting", - }, - description="Hello in middle of sentence", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked response for greeting"}) - .build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" - assert result.actual_outputs - assert any(isinstance(event, GraphRunStartedEvent) for event in result.events) - assert any(isinstance(event, GraphRunSucceededEvent) for event in result.events) - - start_index = next( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunStartedEvent) - ) - success_index = max( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunSucceededEvent) - ) - assert start_index < success_index - - started_node_ids = {event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)} - assert {"1755502773326", "1755502777322"}.issubset(started_node_ids), ( - f"Branch or LLM nodes missing in events: {started_node_ids}" - ) - - assert any(isinstance(event, NodeRunStreamChunkEvent) for event in result.events), ( - "Expected streaming chunks from LLM execution" - ) - - llm_start_index = next( - idx - for idx, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "1755502777322" - ) - assert any( - idx > llm_start_index and isinstance(event, NodeRunStreamChunkEvent) - for idx, event in enumerate(result.events) - ), "Streaming chunks should follow LLM node start" - - def test_non_hello_branch_with_llm(self): - """ - Test when query doesn't contain 'hello' - should trigger false branch. - LLM output should be used as the final answer. 
- """ - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="goodbye world", - expected_outputs={ - "answer": "Mocked LLM response for goodbye", - }, - description="Goodbye case - false branch with LLM output", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked LLM response for goodbye"}) - .build() - ), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="test message", - expected_outputs={ - "answer": "Mocked response for test", - }, - description="Regular message - false branch", - use_auto_mock=True, - mock_config=( - MockConfigBuilder().with_node_output("1755502777322", {"text": "Mocked response for test"}).build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py deleted file mode 100644 index 3851480731..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Test for streaming output workflow behavior. - -This test validates that: -- When blocking == 1: No NodeRunStreamChunkEvent (flow through Template node) -- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) -""" - -from graphon.enums import BuiltinNodeTypes -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner - - -def test_streaming_output_with_blocking_equals_one(): - """ - Test workflow when blocking == 1 (LLM → Template → End). - - Template node doesn't produce streaming output, so no NodeRunStreamChunkEvent should be present. - This test should FAIL according to requirements. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 1}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # According to requirements, we expect exactly 3 streaming events from the End node - # 1. User query - # 2. Newline - # 3. 
Template output (which contains the LLM response) - assert stream_chunk_count == 3, f"Expected 3 streaming events when blocking=1, but got {stream_chunk_count}" - - first_chunk, second_chunk, third_chunk = stream_chunk_events[0], stream_chunk_events[1], stream_chunk_events[2] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - # Third chunk will be the template output with the mock LLM response - assert isinstance(third_chunk.chunk, str), f"Expected third chunk to be string, but got {type(third_chunk.chunk)}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all Template's NodeRunStreamChunkEvent should has same id with Template's NodeRunStartedEvent - start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM - ] - template_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM] - assert len(template_chunk_events) == 1, f"Expected 1 template chunk event, but got {len(template_chunk_events)}" - assert all(e.id in [se.id for se in start_events] for e in template_chunk_events), ( - "Expected all Template chunk events to have same id with Template's NodeRunStartedEvent" - ) - - # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) - - -def test_streaming_output_with_blocking_not_equals_one(): - """ - Test workflow when blocking != 1 (LLM → End directly). - - End node should produce streaming output with NodeRunStreamChunkEvent. - This test should PASS according to requirements. 
- """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 2}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - expecting streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # This assertion should PASS according to requirements - assert stream_chunk_count > 0, f"Expected streaming events when blocking!=1, but got {stream_chunk_count}" - - # We should have at least 2 chunks (query and newline) - assert stream_chunk_count >= 2, f"Expected at least 2 streaming events, but got {stream_chunk_count}" - - first_chunk, second_chunk = stream_chunk_events[0], stream_chunk_events[1] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # With auto-mock, the LLM will produce mock responses - just verify we have streaming chunks - # and they are strings - for chunk_event in stream_chunk_events[2:]: - assert isinstance(chunk_event.chunk, str), f"Expected chunk to be string, but got {type(chunk_event.chunk)}" - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all LLM's NodeRunStreamChunkEvent should be from LLM nodes - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.LLM] - llm_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.LLM] - llm_node_ids = {se.node_id for se in start_events} - assert all(e.node_id in llm_node_ids for e in llm_chunk_events), ( - "Expected all LLM chunk 
events to be from LLM nodes" - ) - - # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py deleted file mode 100644 index ae7dd48bb1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Utilities for detecting if database service is available for workflow tests. -""" - -import psycopg2 -import pytest - -from configs import dify_config - - -def is_database_available() -> bool: - """ - Check if the database service is available by attempting to connect to it. - - Returns: - True if database is available, False otherwise. - """ - try: - # Try to establish a database connection using a context manager - with psycopg2.connect( - host=dify_config.DB_HOST, - port=dify_config.DB_PORT, - database=dify_config.DB_DATABASE, - user=dify_config.DB_USERNAME, - password=dify_config.DB_PASSWORD, - connect_timeout=2, # 2 second timeout - ) as conn: - pass # Connection established and will be closed automatically - return True - except (psycopg2.OperationalError, psycopg2.Error): - return False - - -def skip_if_database_unavailable(): - """ - Pytest skip decorator that skips tests when database service is unavailable. - - Usage: - @skip_if_database_unavailable() - def test_my_workflow(): - ... 
- """ - return pytest.mark.skipif( - not is_database_available(), - reason="Database service is not available (connection refused or authentication failed)", - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py deleted file mode 100644 index 3ee34e86c6..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ /dev/null @@ -1,72 +0,0 @@ -import queue -from datetime import datetime - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.orchestration.dispatcher import Dispatcher -from graphon.graph_events import NodeRunSucceededEvent -from graphon.node_events import NodeRunResult - - -class StubExecutionCoordinator: - def __init__(self, paused: bool) -> None: - self._paused = paused - self.mark_complete_called = False - self.failed_error: Exception | None = None - - @property - def aborted(self) -> bool: - return False - - @property - def paused(self) -> bool: - return self._paused - - @property - def execution_complete(self) -> bool: - return False - - def check_scaling(self) -> None: - return None - - def process_commands(self) -> None: - return None - - def mark_complete(self) -> None: - self.mark_complete_called = True - - def mark_failed(self, error: Exception) -> None: - self.failed_error = error - - -class StubEventHandler: - def __init__(self) -> None: - self.events: list[object] = [] - - def dispatch(self, event: object) -> None: - self.events.append(event) - - -def test_dispatcher_drains_events_when_paused() -> None: - event_queue: queue.Queue = queue.Queue() - event = NodeRunSucceededEvent( - id="exec-1", - node_id="node-1", - node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - event_queue.put(event) - - handler = StubEventHandler() - coordinator = StubExecutionCoordinator(paused=True) - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=handler, - execution_coordinator=coordinator, - event_emitter=None, - ) - - dispatcher._dispatcher_loop() - - assert handler.events == [event] - assert coordinator.mark_complete_called is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py deleted file mode 100644 index ada55f3dc5..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Test case for end node without value_type field (backward compatibility). - -This test validates that end nodes work correctly even when the value_type -field is missing from the output configuration, ensuring backward compatibility -with older workflow definitions. -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_end_node_without_value_type_field(): - """ - Test that end node works without explicit value_type field. - - The fixture implements a simple workflow that: - 1. Takes a query input from start node - 2. Passes it directly to end node - 3. End node outputs the value without specifying value_type - 4. 
Should correctly infer the type and output the value - - This ensures backward compatibility with workflow definitions - created before value_type became a required field. - """ - fixture_name = "end_node_without_value_type_field_workflow" - - case = WorkflowTestCase( - fixture_path=fixture_name, - inputs={"query": "test query"}, - expected_outputs={"query": "test query"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start node - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # Start node streams the input value - NodeRunSucceededEvent, - # End node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - description="End node without value_type field should work correctly", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == {"query": "test query"}, ( - f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py deleted file mode 100644 index 95a94110d2..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Unit tests for the execution coordinator orchestration logic.""" - -from unittest.mock import MagicMock - -import pytest - -from graphon.graph_engine.command_processing.command_processor import CommandProcessor -from graphon.graph_engine.domain.graph_execution import GraphExecution -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from graphon.graph_engine.worker_management.worker_pool import WorkerPool - - -def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: - command_processor = MagicMock(spec=CommandProcessor) - state_manager = MagicMock(spec=GraphStateManager) - worker_pool = MagicMock(spec=WorkerPool) - - coordinator = ExecutionCoordinator( - graph_execution=graph_execution, - state_manager=state_manager, - command_processor=command_processor, - worker_pool=worker_pool, - ) - return coordinator, state_manager, worker_pool - - -def test_handle_pause_stops_workers_and_clears_state() -> None: - """Paused execution should stop workers and clear executing state.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - graph_execution.pause("Awaiting human input") - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_called_once_with() - state_manager.clear_executing.assert_called_once_with() - - -def test_handle_pause_noop_when_execution_running() -> None: - """Running execution should not trigger pause handling.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_not_called() - state_manager.clear_executing.assert_not_called() - - -def test_has_executing_nodes_requires_pause() -> None: - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, _, _ = _build_coordinator(graph_execution) - - with 
pytest.raises(AssertionError): - coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py deleted file mode 100644 index 51ece26d49..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ /dev/null @@ -1,770 +0,0 @@ -""" -Table-driven test framework for GraphEngine workflows. - -This file contains property-based tests and specific workflow tests. -The core test framework is in test_table_runner.py. -""" - -import time - -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from graphon.entities.base_node_data import DefaultValue, DefaultValueType -from graphon.enums import ErrorStrategy -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_events import ( - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Import the test framework from the new module -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase - - -# Property-based fuzzing tests for the start-end workflow -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_echo_workflow_property_basic_strings(query_input): - """ - Property-based test: Echo workflow should return exactly what was input. - - This tests the fundamental property that for any string input, - the start-end workflow should echo it back unchanged. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Fuzzing test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should equal input (echo behavior) - assert result.actual_outputs - assert result.actual_outputs == {"query": query_input}, ( - f"Echo property violated. Input: {repr(query_input)}, " - f"Expected: {repr(query_input)}, Got: {repr(result.actual_outputs.get('query'))}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_echo_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds to test edge cases more efficiently. - - Tests strings up to 1000 characters to balance thoroughness with performance. 
- """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Bounded fuzzing test (len={len(query_input)})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == {"query": query_input} - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="🎉🌟💫⭐🔥💯🚀🎯"), # Emojis - st.text(alphabet="αβγδεζηθικλμνξοπρστυφχψω"), # Greek letters - st.text(alphabet="中文测试한국어日本語العربية"), # International characters - st.just(""), # Empty string - st.just(" " * 100), # Whitespace only - st.just("\n\t\r\f\v"), # Special whitespace chars - st.just('{"json": "like", "data": [1, 2, 3]}'), # JSON-like string - st.just("SELECT * FROM users; DROP TABLE users;--"), # SQL injection attempt - st.just(""), # XSS attempt - st.just("../../etc/passwd"), # Path traversal attempt - ) -) -@settings(max_examples=40, deadline=25000) -def test_echo_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types including edge cases and security payloads. - - Tests various categories of potentially problematic inputs: - - Unicode characters from different languages - - Emojis and special symbols - - Whitespace variations - - Malicious payloads (SQL injection, XSS, path traversal) - - JSON-like structures - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Diverse input fuzzing: {type(query_input).__name__}", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Echo behavior must be preserved regardless of input type - assert result.actual_outputs == {"query": query_input} - - -@given(query_input=st.text(min_size=1000, max_size=5000)) -@settings(max_examples=10, deadline=60000) -def test_echo_workflow_property_large_inputs(query_input): - """ - Property-based test for large inputs to test memory and performance boundaries. - - Tests the system's ability to handle larger payloads efficiently. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Large input test (size: {len(query_input)} chars)", - timeout=45.0, # Longer timeout for large inputs - ) - - start_time = time.perf_counter() - result = runner.run_test_case(test_case) - execution_time = time.perf_counter() - start_time - - # Property: Large inputs should still work - assert result.success, f"Large input workflow failed: {result.error}" - - # Property: Echo behavior preserved for large inputs - assert result.actual_outputs == {"query": query_input} - - # Property: Performance should be reasonable even for large inputs - assert execution_time < 30.0, f"Large input took too long: {execution_time:.2f}s" - - -def test_echo_workflow_robustness_smoke_test(): - """ - Smoke test to ensure the basic workflow functionality works before fuzzing. 
- - This test uses a simple, known-good input to verify the test infrastructure - is working correctly before running the fuzzing tests. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "smoke test"}, - expected_outputs={"query": "smoke test"}, - description="Smoke test for basic functionality", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Smoke test failed: {result.error}" - assert result.actual_outputs == {"query": "smoke test"} - assert result.execution_time > 0 - - -def test_if_else_workflow_true_branch(): - """ - Test if-else workflow when input contains 'hello' (true branch). - - Should output {"true": input_query} when query contains "hello". - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello world"}, - expected_outputs={"true": "hello world"}, - description="Basic hello case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "say hello to everyone"}, - expected_outputs={"true": "say hello to everyone"}, - description="Hello in middle of sentence", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello"}, - expected_outputs={"true": "hello"}, - description="Just hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hellohello"}, - expected_outputs={"true": "hellohello"}, - description="Multiple hello occurrences", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (true branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'true' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_false_branch(): - """ - Test if-else workflow when input does not contain 'hello' (false branch). - - Should output {"false": input_query} when query does not contain "hello". - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "goodbye world"}, - expected_outputs={"false": "goodbye world"}, - description="Basic goodbye case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hi there"}, - expected_outputs={"false": "hi there"}, - description="Simple greeting without hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": ""}, - expected_outputs={"false": ""}, - description="Empty string", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "test message"}, - expected_outputs={"false": "test message"}, - description="Regular message", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (false branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'false' key in outputs for {result.test_case.description}. 
" - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_edge_cases(): - """ - Test if-else workflow edge cases and case sensitivity. - - Tests various edge cases including case sensitivity, similar words, etc. - """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "Hello world"}, - expected_outputs={"false": "Hello world"}, - description="Capitalized Hello (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "HELLO"}, - expected_outputs={"false": "HELLO"}, - description="All caps HELLO (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helllo"}, - expected_outputs={"false": "helllo"}, - description="Typo: helllo (with extra l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helo"}, - expected_outputs={"false": "helo"}, - description="Typo: helo (missing l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello123"}, - expected_outputs={"true": "hello123"}, - description="Hello with numbers", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello!@#"}, - expected_outputs={"true": "hello!@#"}, - description="Hello with special characters", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": " hello "}, - expected_outputs={"true": " hello "}, - description="Hello with surrounding spaces", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected exact match for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_if_else_workflow_property_basic_strings(query_input): - """ - Property-based test: If-else workflow should output correct branch based on 'hello' content. - - This tests the fundamental property that for any string input: - - If input contains "hello", output should be {"true": input} - - If input doesn't contain "hello", output should be {"false": input} - """ - runner = TableTestRunner() - - # Determine expected output based on whether input contains "hello" - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Property test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should contain ONLY the expected key with correct value - assert result.actual_outputs == expected_outputs, ( - f"If-else property violated. 
Input: {repr(query_input)}, " - f"Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_if_else_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds for if-else workflow. - - Tests strings up to 1000 characters to balance thoroughness with performance. - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Bounded if-else test (len={len(query_input)}, contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == expected_outputs - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="hello"), # Strings that definitely contain hello - st.text(alphabet="xyz"), # Strings that definitely don't contain hello - st.just("hello world"), # Known true case - st.just("goodbye world"), # Known false case - st.just(""), # Empty string - st.just("Hello"), # Case sensitivity test - st.just("HELLO"), # Case sensitivity test - st.just("hello" * 10), # Multiple hello occurrences - st.just("say hello to everyone"), # Hello in middle - st.text(alphabet="🎉🌟💫⭐🔥💯🚀🎯"), # Emojis - st.text(alphabet="中文测试한국어日本語العربية"), # International characters - ) -) -@settings(max_examples=40, deadline=25000) -def test_if_else_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types for if-else workflow. - - Tests various categories including: - - Known true/false cases - - Case sensitivity scenarios - - Unicode characters from different languages - - Emojis and special symbols - - Multiple hello occurrences - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Diverse if-else test: {type(query_input).__name__} (contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Correct branch logic must be preserved regardless of input type - assert result.actual_outputs == expected_outputs, ( - f"Branch logic violated. 
Input: {repr(query_input)}, " - f"Contains 'hello': {contains_hello}, Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -# Tests for the Layer system -def test_layer_system_basic(): - """Test basic layer functionality with DebugLoggingLayer.""" - from graphon.graph_engine.layers import DebugLoggingLayer - - runner = WorkflowRunner() - - # Load a simple echo workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test layer system"}) - - # Create engine with layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add debug logging layer - debug_layer = DebugLoggingLayer(level="DEBUG", include_inputs=True, include_outputs=True) - engine.layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify events were generated - assert len(events) > 0 - assert isinstance(events[0], GraphRunStartedEvent) - assert isinstance(events[-1], GraphRunSucceededEvent) - - # Verify layer received context - assert debug_layer.graph_runtime_state is not None - assert debug_layer.command_channel is not None - - # Verify layer tracked execution stats - assert debug_layer.node_count > 0 - assert debug_layer.success_count > 0 - - -def test_layer_chaining(): - """Test chaining multiple layers.""" - from graphon.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer - - # Create a custom test layer - class TestLayer(GraphEngineLayer): - def __init__(self): - super().__init__() - self.events_received = [] - self.graph_started = False - self.graph_ended = False - - def on_graph_start(self): - self.graph_started = True - - def on_event(self, event): - self.events_received.append(event.__class__.__name__) - - def on_graph_end(self, error): - self.graph_ended = True - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test chaining"}) - - # Create engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Chain multiple layers - test_layer = TestLayer() - debug_layer = DebugLoggingLayer(level="INFO") - - engine.layer(test_layer).layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify both layers received events - assert test_layer.graph_started - assert test_layer.graph_ended - assert len(test_layer.events_received) > 0 - - # Verify debug layer also worked - assert debug_layer.node_count > 0 - - -def test_layer_error_handling(): - """Test that layer errors don't crash the engine.""" - from graphon.graph_engine.layers import GraphEngineLayer - - # Create a layer that throws errors - class FaultyLayer(GraphEngineLayer): - def on_graph_start(self): - raise RuntimeError("Intentional error in on_graph_start") - - def on_event(self, event): - raise RuntimeError("Intentional error in on_event") - - def on_graph_end(self, error): - raise RuntimeError("Intentional error in on_graph_end") - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test error 
handling"}) - - # Create engine with faulty layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add faulty layer - engine.layer(FaultyLayer()) - - # Run workflow - should not crash despite layer errors - events = list(engine.run()) - - # Verify workflow still completed successfully - assert len(events) > 0 - assert isinstance(events[-1], GraphRunSucceededEvent) - assert events[-1].outputs == {"query": "test error handling"} - - -def test_event_sequence_validation(): - """Test the new event sequence validation feature.""" - from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - # Test 1: Successful event sequence validation - test_case_success = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test event sequence"}, - expected_outputs={"query": "test event sequence"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, # Start node begins - NodeRunStreamChunkEvent, # Start node streaming - NodeRunSucceededEvent, # Start node completes - NodeRunStartedEvent, # End node begins - NodeRunSucceededEvent, # End node completes - GraphRunSucceededEvent, # Graph completes - ], - description="Test with correct event sequence", - ) - - result = runner.run_test_case(test_case_success) - assert result.success, f"Test should pass with correct event sequence. Error: {result.event_mismatch_details}" - assert result.event_sequence_match is True - assert result.event_mismatch_details is None - - # Test 2: Failed event sequence validation - wrong order - test_case_wrong_order = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong order"}, - expected_outputs={"query": "test wrong order"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunSucceededEvent, # Wrong: expecting success before start - NodeRunStreamChunkEvent, - NodeRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Test with incorrect event order", - ) - - result = runner.run_test_case(test_case_wrong_order) - assert not result.success, "Test should fail with incorrect event sequence" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event mismatch at position" in result.event_mismatch_details - - # Test 3: Failed event sequence validation - wrong count - test_case_wrong_count = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong count"}, - expected_outputs={"query": "test wrong count"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Missing the second node's events - GraphRunSucceededEvent, - ], - description="Test with incorrect event count", - ) - - result = runner.run_test_case(test_case_wrong_count) - assert not result.success, "Test should fail with incorrect event count" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event count mismatch" in result.event_mismatch_details - - # Test 4: No event sequence validation (backward compatibility) - test_case_no_validation = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test no validation"}, - expected_outputs={"query": "test no validation"}, - # No 
expected_event_sequence provided - description="Test without event sequence validation", - ) - - result = runner.run_test_case(test_case_no_validation) - assert result.success, "Test should pass when no event sequence is provided" - assert result.event_sequence_match is None - assert result.event_mismatch_details is None - - -def test_event_sequence_validation_with_table_tests(): - """Test event sequence validation with table-driven tests.""" - from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test1"}, - expected_outputs={"query": "test1"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 1: Valid sequence", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test2"}, - expected_outputs={"query": "test2"}, - # No event sequence validation for this test - description="Table test 2: No sequence validation", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test3"}, - expected_outputs={"query": "test3"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 3: Valid sequence", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - # Check all tests passed - for i, result in enumerate(suite_result.results): - if i == 1: # Test 2 has no event sequence validation - assert result.event_sequence_match is None - else: - assert result.event_sequence_match is True - assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" - - -def test_graph_run_emits_partial_success_when_node_failure_recovered(): - runner = TableTestRunner() - - fixture_data = runner.workflow_runner.load_fixture("basic_chatflow") - mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build() - - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - query="hello", - use_mock_factory=True, - mock_config=mock_config, - ) - - llm_node = graph.nodes["llm"] - base_node_data = llm_node.node_data - base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE - base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)] - - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - events = list(engine.run()) - - assert isinstance(events[-1], GraphRunPartialSucceededEvent) - - partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent)) - assert partial_event.exceptions_count == 1 - assert partial_event.outputs.get("answer") == "fallback response" - - assert not any(isinstance(event, GraphRunSucceededEvent) for event in events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py deleted file mode 100644 index 348ceb6788..0000000000 --- 
a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Unit tests for GraphExecution serialization helpers.""" - -from __future__ import annotations - -import json -from collections import deque -from unittest.mock import MagicMock - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from graphon.graph_engine.domain import GraphExecution -from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator -from graphon.graph_engine.response_coordinator.path import Path -from graphon.graph_engine.response_coordinator.session import ResponseSession -from graphon.graph_events import NodeRunStreamChunkEvent -from graphon.nodes.base.template import Template, TextSegment, VariableSegment - - -class CustomGraphExecutionError(Exception): - """Custom exception used to verify error serialization.""" - - -def test_graph_execution_serialization_round_trip() -> None: - """GraphExecution serialization restores full aggregate state.""" - # Arrange - execution = GraphExecution(workflow_id="wf-1") - execution.start() - node_a = execution.get_or_create_node_execution("node-a") - node_a.mark_started(execution_id="exec-1") - node_a.increment_retry() - node_a.mark_failed("boom") - node_b = execution.get_or_create_node_execution("node-b") - node_b.mark_skipped() - execution.fail(CustomGraphExecutionError("serialization failure")) - - # Act - serialized = execution.dumps() - payload = json.loads(serialized) - restored = GraphExecution(workflow_id="wf-1") - restored.loads(serialized) - - # Assert - assert payload["type"] == "GraphExecution" - assert payload["version"] == "1.0" - assert restored.workflow_id == "wf-1" - assert restored.started is True - assert restored.completed is True - assert restored.aborted is False - assert isinstance(restored.error, CustomGraphExecutionError) - assert str(restored.error) == "serialization failure" - assert set(restored.node_executions) == {"node-a", "node-b"} - restored_node_a = restored.node_executions["node-a"] - assert restored_node_a.state is NodeState.TAKEN - assert restored_node_a.retry_count == 1 - assert restored_node_a.execution_id == "exec-1" - assert restored_node_a.error == "boom" - restored_node_b = restored.node_executions["node-b"] - assert restored_node_b.state is NodeState.SKIPPED - assert restored_node_b.retry_count == 0 - assert restored_node_b.execution_id is None - assert restored_node_b.error is None - - -def test_graph_execution_loads_replaces_existing_state() -> None: - """loads replaces existing runtime data with serialized snapshot.""" - # Arrange - source = GraphExecution(workflow_id="wf-2") - source.start() - source_node = source.get_or_create_node_execution("node-source") - source_node.mark_taken() - serialized = source.dumps() - - target = GraphExecution(workflow_id="wf-2") - target.start() - target.abort("pre-existing abort") - temp_node = target.get_or_create_node_execution("node-temp") - temp_node.increment_retry() - temp_node.mark_failed("temp error") - - # Act - target.loads(serialized) - - # Assert - assert target.aborted is False - assert target.error is None - assert target.started is True - assert target.completed is False - assert set(target.node_executions) == {"node-source"} - restored_node = target.node_executions["node-source"] - assert restored_node.state is NodeState.TAKEN - assert restored_node.retry_count == 0 - assert restored_node.execution_id is None - assert restored_node.error is None - - -def 
test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None: - """ResponseStreamCoordinator serialization restores coordinator internals.""" - - template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])]) - template_secondary = Template(segments=[TextSegment(text="secondary")]) - - class DummyNode: - def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: - self.id = node_id - self.node_type = ( - BuiltinNodeTypes.ANSWER if execution_type == NodeExecutionType.RESPONSE else BuiltinNodeTypes.LLM - ) - self.execution_type = execution_type - self.state = NodeState.UNKNOWN - self.title = node_id - self.template = template - - def blocks_variable_output(self, *_args) -> bool: - return False - - response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE) - response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE) - response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE) - source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE) - - class DummyGraph: - def __init__(self) -> None: - self.nodes = { - response_node1.id: response_node1, - response_node2.id: response_node2, - response_node3.id: response_node3, - source_node.id: source_node, - } - self.edges: dict[str, object] = {} - self.root_node = response_node1 - - def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - graph = DummyGraph() - - def fake_from_node(cls, node: DummyNode) -> ResponseSession: - return ResponseSession(node_id=node.id, template=node.template) - - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - - coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - coordinator._response_nodes = {"response-1", "response-2", "response-3"} - coordinator._paths_maps = { - "response-1": [Path(edges=["edge-1"])], - "response-2": [Path(edges=[])], - "response-3": [Path(edges=["edge-2", "edge-3"])], - } - - active_session = ResponseSession(node_id="response-1", template=response_node1.template) - active_session.index = 1 - coordinator._active_session = active_session - waiting_session = ResponseSession(node_id="response-2", template=response_node2.template) - coordinator._waiting_sessions = deque([waiting_session]) - pending_session = ResponseSession(node_id="response-3", template=response_node3.template) - pending_session.index = 2 - coordinator._response_sessions = {"response-3": pending_session} - - coordinator._node_execution_ids = {"response-1": "exec-1"} - event = NodeRunStreamChunkEvent( - id="exec-1", - node_id="response-1", - node_type=BuiltinNodeTypes.ANSWER, - selector=["node-source", "text"], - chunk="chunk-1", - is_final=False, - ) - coordinator._stream_buffers = {("node-source", "text"): [event]} - coordinator._stream_positions = {("node-source", "text"): 1} - coordinator._closed_streams = {("node-source", "text")} - - serialized = coordinator.dumps() - - restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - restored.loads(serialized) - - assert restored._response_nodes == {"response-1", "response-2", "response-3"} - assert restored._paths_maps["response-1"][0].edges == 
["edge-1"] - assert restored._active_session is not None - assert restored._active_session.node_id == "response-1" - assert restored._active_session.index == 1 - waiting_restored = list(restored._waiting_sessions) - assert len(waiting_restored) == 1 - assert waiting_restored[0].node_id == "response-2" - assert waiting_restored[0].index == 0 - assert set(restored._response_sessions) == {"response-3"} - assert restored._response_sessions["response-3"].index == 2 - assert restored._node_execution_ids == {"response-1": "exec-1"} - assert ("node-source", "text") in restored._stream_buffers - restored_event = restored._stream_buffers[("node-source", "text")][0] - assert restored_event.chunk == "chunk-1" - assert restored._stream_positions[("node-source", "text")] == 1 - assert ("node-source", "text") in restored._closed_streams diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py deleted file mode 100644 index a6417822d2..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ /dev/null @@ -1,190 +0,0 @@ -import time -from collections.abc import Mapping - -from core.workflow.system_variables import build_system_variables -from graphon.entities import GraphInitParams -from graphon.enums import NodeState -from graphon.graph import Graph -from graphon.graph_engine.graph_state_manager import GraphStateManager -from graphon.graph_engine.ready_queue import InMemoryReadyQueue -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_llm_node( - *, - node_id: str, - runtime_state: GraphRuntimeState, - graph_init_params: GraphInitParams, - mock_config: MockConfig, -) -> MockLLMNode: - llm_data = LLMNodeData( - title=f"LLM {node_id}", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=f"Prompt {node_id}", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - return MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - -def _build_graph(runtime_state: GraphRuntimeState) -> Graph: - graph_config: dict[str, object] = 
{"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - mock_config = MockConfig() - llm_a = _build_llm_node( - node_id="llm_a", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - llm_b = _build_llm_node( - node_id="llm_b", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - - end_data = EndNodeData(title="End", outputs=[], desc=None) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(llm_b, from_node_id="start") - .add_node(end_node, from_node_id="llm_a") - ) - return builder.connect(tail="llm_b", head="end").build() - - -def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: - return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} - - -def test_runtime_state_snapshot_restores_graph_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - graph.nodes["llm_a"].state = NodeState.TAKEN - graph.nodes["llm_b"].state = NodeState.SKIPPED - - for edge in graph.edges.values(): - if edge.tail == "start" and edge.head == "llm_a": - edge.state = NodeState.TAKEN - elif edge.tail == "start" and edge.head == "llm_b": - edge.state = NodeState.SKIPPED - elif edge.head == "end" and edge.tail == "llm_a": - edge.state = NodeState.TAKEN - elif edge.head == "end" and edge.tail == "llm_b": - edge.state = NodeState.SKIPPED - - snapshot = runtime_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN - assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED - assert _edge_state_map(resumed_graph) == _edge_state_map(graph) - - -def test_join_readiness_uses_restored_edge_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - ready_queue = InMemoryReadyQueue() - state_manager = GraphStateManager(graph, ready_queue) - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_a": - edge.state = NodeState.TAKEN - if edge.tail == "llm_b": - edge.state = NodeState.UNKNOWN - - assert state_manager.is_node_ready("end") is False - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_b": - edge.state = NodeState.TAKEN - - assert state_manager.is_node_ready("end") is True - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) - assert 
resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py deleted file mode 100644 index ca9a929591..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ /dev/null @@ -1,389 +0,0 @@ -import datetime -import time -from collections.abc import Iterable -from unittest import mock -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_branching_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - 
edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="primary", title="Primary"), - UserAction(id="secondary", title="Secondary"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(human_node) - .add_node(llm_primary, from_node_id="human", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="human", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def _assert_stream_chunk_sequence( - chunk_events: Iterable[NodeRunStreamChunkEvent], - expected_nodes: list[str], - expected_chunks: list[str], -) -> None: - actual_nodes = [event.node_id for event in chunk_events] - actual_chunks = [event.chunk for event in chunk_events] - assert actual_nodes == expected_nodes - assert actual_chunks == 
expected_chunks - - -def test_human_input_llm_streaming_across_multiple_branches() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - branch_scenarios = [ - { - "handle": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_primary", ["\n"]), # literal segment emitted when end_primary session activates - ], - "expected_post_chunks": [ - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), # live stream from chosen branch - ], - }, - { - "handle": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_secondary", ["\n"]), # literal segment emitted when end_secondary session activates - ], - "expected_post_chunks": [ - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), # live stream from chosen branch - ], - }, - ] - - for scenario in branch_scenarios: - runner = TableTestRunner() - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.submission_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause before branching decision", - graph_factory=initial_graph_factory, - expected_event_sequence=[ - GraphRunStartedEvent, # initial run: graph execution starts - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and issues pause - NodeRunPauseRequestedEvent, # human node requests pause awaiting input - GraphRunPausedEvent, # graph run pauses awaiting resume - ], - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) - - pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) - post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) - expected_pre_chunk_events_in_resumption = [ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunHumanInputFormFilledEvent, - ] - - expected_resume_sequence: list[type] = ( - expected_pre_chunk_events_in_resumption - + [NodeRunStreamChunkEvent] * pre_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * post_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] - ) - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) 
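Reviewer note: the resume half of each scenario hinges on a repository stub that reports the form as already submitted, with selected_action_id picking the branch. A condensed sketch of the stub assembled here, assuming the same specs the test mocks against; make_submitted_repo is an illustrative helper, not part of the original file:

```python
# Condensed sketch of the "already submitted" form stub, using the same
# specs as the surrounding test. Helper name is illustrative only.
from unittest.mock import MagicMock

from core.repositories.human_input_repository import (
    HumanInputFormEntity,
    HumanInputFormRepository,
)


def make_submitted_repo(action_id: str) -> MagicMock:
    form = MagicMock(spec=HumanInputFormEntity)
    form.submitted = True                # resume path sees a filled form
    form.selected_action_id = action_id  # decides which branch executes
    form.submitted_data = {}
    repo = MagicMock(spec=HumanInputFormRepository)
    repo.get_form.return_value = form    # engine asks the repo on resume
    return repo
```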
- submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.submission_token = mock_form_entity.submission_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = scenario["handle"] - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory( - initial_result=initial_result, mock_get_repo=mock_get_repo - ) -> tuple[Graph, GraphRuntimeState]: - assert initial_result.graph_runtime_state is not None - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) - - resume_case = WorkflowTestCase( - description=f"HumanInput resumes via {scenario['handle']} branch", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert len(chunk_events) == pre_chunk_count + post_chunk_count - - pre_chunk_events = chunk_events[:pre_chunk_count] - post_chunk_events = chunk_events[pre_chunk_count:] - - expected_pre_nodes: list[str] = [] - expected_pre_chunks: list[str] = [] - for node_id, chunks in scenario["expected_pre_chunks"]: - expected_pre_nodes.extend([node_id] * len(chunks)) - expected_pre_chunks.extend(chunks) - _assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks) - - expected_post_nodes: list[str] = [] - expected_post_chunks: list[str] = [] - for node_id, chunks in scenario["expected_post_chunks"]: - expected_post_nodes.extend([node_id] * len(chunks)) - expected_post_chunks.extend(chunks) - _assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks) - - human_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" - ) - pre_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index - ] - expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) - assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) - - resume_chunk_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices, "Expected streaming output from the selected branch" - resume_start_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in resume_events 
if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py deleted file mode 100644 index c50aaafe2c..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ /dev/null @@ -1,346 +0,0 @@ -import datetime -import time -from unittest import mock -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_llm_human_llm_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=build_system_variables( - user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, -
role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt") - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="accept", title="Accept"), - UserAction(id="reject", title="Reject"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), - ) - - llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") - - end_data = EndNodeData( - title="End", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="resume_text", value_type=OutputVariableType.STRING, value_selector=["llm_resume", "text"] - ), - ], - desc=None, - ) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_first) - .add_node(human_node) - .add_node(llm_second, source_handle="accept") - .add_node(end_node) - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_human_input_llm_streaming_order_across_pause() -> None: - runner = TableTestRunner() - - initial_text = "Hello, pause" - resume_text = "Welcome back!" 
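Reviewer note: the chunking helper defined above emits one chunk per word, prefixes every chunk after the first with a space, and appends an empty end-of-stream chunk. Worked examples for the two strings this test streams, evaluated against the helper as written:

```python
# Worked examples of the _expected_mock_llm_chunks helper defined above.
assert _expected_mock_llm_chunks("Hello, pause") == ["Hello,", " pause", ""]
assert _expected_mock_llm_chunks("Welcome back!") == ["Welcome", " back!", ""]
```

This is why the assertions below slice the resume chunk events in groups of three: two word chunks plus the trailing empty marker per mocked LLM output.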
- - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": initial_text}) - mock_config.set_node_outputs("llm_resume", {"text": resume_text}) - - expected_initial_sequence: list[type] = [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial begins streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and requests pause - NodeRunPauseRequestedEvent, # human node pause requested - GraphRunPausedEvent, # graph run pauses awaiting resume - ] - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.submission_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause preserves LLM streaming order", - graph_factory=graph_factory, - expected_event_sequence=expected_initial_sequence, - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - - initial_events = initial_result.events - initial_chunks = _expected_mock_llm_chunks(initial_text) - - initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)] - assert initial_stream_chunk_events == [] - - pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent)) - llm_succeeded_index = next( - i - for i, event in enumerate(initial_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial" - ) - assert llm_succeeded_index < pause_index - - graph_runtime_state = initial_result.graph_runtime_state - graph = initial_result.graph - assert graph_runtime_state is not None - assert graph is not None - - coordinator = graph_runtime_state.response_coordinator - stream_buffers = coordinator._stream_buffers # Tests may access internals for assertions - assert ("llm_initial", "text") in stream_buffers - initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]] - assert initial_stream_chunks == initial_chunks - assert ("llm_resume", "text") not in stream_buffers - - resume_chunks = _expected_mock_llm_chunks(resume_text) - expected_resume_sequence: list[type] = [ - GraphRunStartedEvent, # resumed graph run begins - NodeRunStartedEvent, # human node restarts - # The form-filled event should be generated first; then the node execution ends and the cached stream chunks are emitted.
- NodeRunHumanInputFormFilledEvent, - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 - NodeRunStreamChunkEvent, # cached llm_initial final chunk - NodeRunStreamChunkEvent, # end node emits combined template separator - NodeRunSucceededEvent, # human node finishes instantly after input - NodeRunStartedEvent, # llm_resume begins streaming - NodeRunStreamChunkEvent, # llm_resume chunk 1 - NodeRunStreamChunkEvent, # llm_resume chunk 2 - NodeRunStreamChunkEvent, # llm_resume final chunk - NodeRunSucceededEvent, # llm_resume completes streaming - NodeRunStartedEvent, # end node starts - NodeRunSucceededEvent, # end node finishes - GraphRunSucceededEvent, # graph run succeeds after resume - ] - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.submission_token = mock_form_entity.submission_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = "accept" - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - # reconstruct the graph runtime state from the serialized snapshot - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_llm_human_llm_graph( - mock_config, - mock_get_repo, - resume_runtime_state, - ) - - resume_case = WorkflowTestCase( - description="HumanInput resume continues LLM streaming order", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent)) - llm_resume_succeeded_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - assert llm_resume_succeeded_index < success_index - - resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3 - assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks - assert resume_chunk_events[3].node_id == "end" - assert resume_chunk_events[3].chunk == "\n" - assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3 - assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks - - human_success_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" - ) - cached_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"} - ] - assert all(index < human_success_index for index in cached_chunk_indices) - - llm_resume_start_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume" - ) - llm_resume_success_index = next( - i - for i, event in
enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - llm_resume_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume" - ] - assert llm_resume_chunk_indices - first_resume_chunk_index = min(llm_resume_chunk_indices) - last_resume_chunk_index = max(llm_resume_chunk_indices) - assert llm_resume_start_index < first_resume_chunk_index - assert last_resume_chunk_index < llm_resume_success_index - - started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", "llm_resume", "end"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py deleted file mode 100644 index 246df45d5f..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ /dev/null @@ -1,324 +0,0 @@ -import time -from unittest import mock - -from core.workflow.system_variables import build_system_variables -from graphon.graph import Graph -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.if_else.entities import IfElseNodeData -from graphon.nodes.if_else.if_else_node import IfElseNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.utils.condition.entities import Condition -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - graph_config=graph_config, - user_from="account", - invoke_from="debugger", - ) - - variable_pool = VariablePool( - system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - variable_pool.add(("branch", "value"), branch_value) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - 
text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - if_else_data = IfElseNodeData( - title="IfElse", - cases=[ - IfElseNodeData.Case( - case_id="primary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary") - ], - ), - IfElseNodeData.Case( - case_id="secondary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary") - ], - ), - ], - ) - if_else_config = {"id": "if_else", "data": if_else_data.model_dump()} - if_else_node = IfElseNode( - id=if_else_config["id"], - config=if_else_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(if_else_node) - .add_node(llm_primary, from_node_id="if_else", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="if_else", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_if_else_llm_streaming_order() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - 
mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - scenarios = [ - { - "branch": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_primary begins streaming - NodeRunStreamChunkEvent, # llm_primary chunk 1 - NodeRunStreamChunkEvent, # llm_primary chunk 2 - NodeRunStreamChunkEvent, # llm_primary chunk 3 - NodeRunStreamChunkEvent, # llm_primary final chunk - NodeRunSucceededEvent, # llm_primary completes streaming - NodeRunStartedEvent, # end_primary node starts - NodeRunSucceededEvent, # end_primary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_primary", ["\n"]), - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), - ], - }, - { - "branch": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_secondary begins streaming - NodeRunStreamChunkEvent, # llm_secondary chunk 1 - NodeRunStreamChunkEvent, # llm_secondary final chunk - NodeRunSucceededEvent, # llm_secondary completes - NodeRunStartedEvent, # end_secondary node starts - NodeRunSucceededEvent, # end_secondary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_secondary", ["\n"]), - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), - ], - }, - ] - - for scenario in scenarios: - runner = TableTestRunner() - - def graph_factory( - branch_value: str = scenario["branch"], - cfg: MockConfig = mock_config, - ) -> tuple[Graph, GraphRuntimeState]: - return _build_if_else_graph(branch_value, cfg) - - test_case = WorkflowTestCase( - description=f"IfElse streaming via {scenario['branch']} branch", - graph_factory=graph_factory, - expected_event_sequence=scenario["expected_sequence"], - ) - - result = runner.run_test_case(test_case) - - assert result.success, result.event_mismatch_details - - chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)] - 
expected_nodes: list[str] = [] - expected_chunks: list[str] = [] - for node_id, chunks in scenario["expected_chunks"]: - expected_nodes.extend([node_id] * len(chunks)) - expected_chunks.extend(chunks) - assert [event.node_id for event in chunk_events] == expected_nodes - assert [event.chunk for event in chunk_events] == expected_chunks - - branch_node_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else" - ) - branch_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else" - ) - pre_branch_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index - ] - assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1 - assert min(pre_branch_chunk_indices) == branch_node_index + 1 - assert max(pre_branch_chunk_indices) < branch_success_index - - resume_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices - resume_start_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py b/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py deleted file mode 100644 index b9bf4be13a..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -Test cases for the Iteration node's flatten_output functionality. - -This module tests the iteration node's ability to: -1. Flatten array outputs when flatten_output=True (default) -2. Preserve nested array structure when flatten_output=False -""" - -from .test_database_utils import skip_if_database_unavailable -from .test_mock_config import MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _create_iteration_mock_config(): - """Helper to create a mock config for iteration tests.""" - - def code_inner_handler(node): - pool = node.graph_runtime_state.variable_pool - item_seg = pool.get(["iteration_node", "item"]) - if item_seg is not None: - item = item_seg.to_object() - return {"result": [item, item * 2]} - # This fallback is likely unreachable; if it is ever reached, - # it returns fixed values instead of simulating iteration with different per-item values.
- return {"result": [1, 2]} - - return ( - MockConfigBuilder() - .with_node_output("code_node", {"result": [1, 2, 3]}) - .with_node_config(NodeMockConfig(node_id="code_inner_node", custom_handler=code_inner_handler)) - .build() - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_enabled(): - """ - Test iteration node with flatten_output=True (default behavior). - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=True, should output [1, 2, 2, 4, 3, 6] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="Iteration with flatten_output=True flattens nested arrays", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [1, 2, 2, 4, 3, 6]}, ( - f"Expected flattened output [1, 2, 2, 4, 3, 6], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_disabled(): - """ - Test iteration node with flatten_output=False. - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=False, should output [[1, 2], [2, 4], [3, 6]] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="Iteration with flatten_output=False preserves nested structure", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [[1, 2], [2, 4], [3, 6]]}, ( - f"Expected nested output [[1, 2], [2, 4], [3, 6]], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_flatten_output_comparison(): - """ - Run both flatten_output configurations in parallel to verify the difference. 
- """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="flatten_output=True: Flattened output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="flatten_output=False: Nested output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - ] - - suite_result = runner.run_table_tests(test_cases, parallel=True) - - # Assert all tests passed - assert suite_result.passed_tests == 2, f"Expected 2 passed tests, got {suite_result.passed_tests}" - assert suite_result.failed_tests == 0, f"Expected 0 failed tests, got {suite_result.failed_tests}" - assert suite_result.success_rate == 100.0, f"Expected 100% success rate, got {suite_result.success_rate}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py deleted file mode 100644 index 821da46b76..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Test case for loop with inner answer output error scenario. - -This test validates the behavior of a loop containing an answer node -inside the loop that may produce output errors. -""" - -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_contains_answer(): - """ - Test loop with inner answer node that may have output errors. - - The fixture implements a loop that: - 1. Iterates 4 times (index 0-3) - 2. Contains an inner answer node that outputs index and item values - 3. Has a break condition when index equals 4 - 4. 
Tests error handling for answer nodes within loops - """ - fixture_name = "loop_contains_answer" - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query="1", - expected_outputs={"answer": "1\n2\n1 + 2"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop start - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop next - NodeRunLoopNextEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunStreamChunkEvent, # 2 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop end - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # + - NodeRunStreamChunkEvent, # 2 - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py deleted file mode 100644 index ad8d777ea6..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -Test cases for the Loop node functionality using TableTestRunner. - -This module tests the loop node's ability to: -1. Execute iterations with loop variables -2. Handle break conditions correctly -3. Update and propagate loop variables between iterations -4. Output the final loop variable value -""" - -from tests.unit_tests.core.workflow.graph_engine.test_table_runner import ( - TableTestRunner, - WorkflowTestCase, -) - - -def test_loop_with_break_condition(): - """ - Test loop node with break condition. - - The increment_loop_with_break_condition_workflow.yml fixture implements a loop that: - 1. Starts with num=1 - 2. Increments num by 1 each iteration - 3. Breaks when num >= 5 - 4. 
Should output {"num": 5} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="increment_loop_with_break_condition_workflow", - inputs={}, # No inputs needed for this test - expected_outputs={"num": 5}, - description="Loop with break condition when num >= 5", - ) - - result = runner.run_test_case(test_case) - - # Assert the test passed - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"num": 5}, f"Expected {{'num': 5}}, got {result.actual_outputs}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py deleted file mode 100644 index 4a60c7769c..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ /dev/null @@ -1,72 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_with_tool(): - fixture_name = "search_dify_from_2023_to_2025" - mock_config = ( - MockConfigBuilder() - .with_tool_response( - { - "text": "mocked search result", - } - ) - .build() - ) - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={ - "answer": """- mocked search result -- mocked search result""" - }, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LOOP START - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # 2023 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunVariableUpdatedEvent, - NodeRunSucceededEvent, - NodeRunLoopNextEvent, - # 2024 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunVariableUpdatedEvent, - NodeRunSucceededEvent, - # LOOP END - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # loop.res - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py deleted file mode 100644 index c511548749..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Example demonstrating the auto-mock system for testing workflows. - -This example shows how to test workflows with third-party service nodes -without making actual API calls. -""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def example_test_llm_workflow(): - """ - Example: Testing a workflow with an LLM node. - - This demonstrates how to test a workflow that uses an LLM service - without making actual API calls to OpenAI, Anthropic, etc. 
- """ - print("\n=== Example: Testing LLM Workflow ===\n") - - # Initialize the test runner - runner = TableTestRunner() - - # Configure mock responses - mock_config = MockConfigBuilder().with_llm_response("I'm a helpful AI assistant. How can I help you today?").build() - - # Define the test case - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello, AI!"}, - expected_outputs={"answer": "I'm a helpful AI assistant. How can I help you today?"}, - description="Testing LLM workflow with mocked response", - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - ) - - # Run the test - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Test passed!") - print(f" Input: {test_case.inputs['query']}") - print(f" Output: {result.actual_outputs['answer']}") - print(f" Execution time: {result.execution_time:.2f}s") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_with_custom_outputs(): - """ - Example: Testing with custom outputs for specific nodes. - - This shows how to provide different mock outputs for specific node IDs, - useful when testing complex workflows with multiple LLM/tool nodes. - """ - print("\n=== Example: Custom Node Outputs ===\n") - - runner = TableTestRunner() - - # Configure mock with specific outputs for different nodes - mock_config = MockConfigBuilder().build() - - # Set custom output for a specific LLM node - mock_config.set_node_outputs( - "llm_node", - { - "text": "This is a custom response for the specific LLM node", - "usage": { - "prompt_tokens": 50, - "completion_tokens": 20, - "total_tokens": 70, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Tell me about custom outputs"}, - expected_outputs={"answer": "This is a custom response for the specific LLM node"}, - description="Testing with custom node outputs", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Test with custom outputs passed!") - print(f" Custom output: {result.actual_outputs['answer']}") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_http_and_tool_workflow(): - """ - Example: Testing a workflow with HTTP request and tool nodes. - - This demonstrates mocking external HTTP calls and tool executions. 
- """ - print("\n=== Example: HTTP and Tool Workflow ===\n") - - runner = TableTestRunner() - - # Configure mocks for HTTP and Tool nodes - mock_config = MockConfigBuilder().build() - - # Mock HTTP response - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}', - "headers": {"content-type": "application/json"}, - }, - ) - - # Mock tool response (e.g., JSON parser) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http-tool-workflow", - inputs={"url": "https://api.example.com/users"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - description="Testing HTTP and Tool workflow", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ HTTP and Tool workflow test passed!") - print(f" HTTP Status: {result.actual_outputs['status_code']}") - print(f" Parsed Data: {result.actual_outputs['parsed_data']}") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_error_simulation(): - """ - Example: Simulating errors in specific nodes. - - This shows how to test error handling in workflows by simulating - failures in specific nodes. - """ - print("\n=== Example: Error Simulation ===\n") - - runner = TableTestRunner() - - # Configure mock to simulate an error - mock_config = MockConfigBuilder().build() - mock_config.set_node_error("llm_node", "API rate limit exceeded") - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "This will fail"}, - expected_outputs={}, # We expect failure - description="Testing error handling", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if not result.success: - print("✅ Error simulation worked as expected!") - print(f" Simulated error: {result.error}") - else: - print("❌ Expected failure but test succeeded") - - return not result.success # Success means we got the expected error - - -def example_test_with_delays(): - """ - Example: Testing with simulated execution delays. - - This demonstrates how to simulate realistic execution times - for performance testing. 
- """ - print("\n=== Example: Simulated Delays ===\n") - - runner = TableTestRunner() - - # Configure mock with delays - mock_config = ( - MockConfigBuilder() - .with_delays(True) # Enable delay simulation - .with_llm_response("Response after delay") - .build() - ) - - # Add specific delay for the LLM node - from .test_mock_config import NodeMockConfig - - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.5, # 500ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="Testing with simulated delays", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Delay simulation test passed!") - print(f" Execution time: {result.execution_time:.2f}s") - print(" (Should be >= 0.5s due to simulated delay)") - else: - print(f"❌ Test failed: {result.error}") - - return result.success and result.execution_time >= 0.5 - - -def run_all_examples(): - """Run all example tests.""" - print("\n" + "=" * 50) - print("AUTO-MOCK SYSTEM EXAMPLES") - print("=" * 50) - - examples = [ - example_test_llm_workflow, - example_test_with_custom_outputs, - example_test_http_and_tool_workflow, - example_test_error_simulation, - example_test_with_delays, - ] - - results = [] - for example in examples: - try: - results.append(example()) - except Exception as e: - print(f"\n❌ Example failed with exception: {e}") - results.append(False) - - print("\n" + "=" * 50) - print("SUMMARY") - print("=" * 50) - - passed = sum(results) - total = len(results) - print(f"\n✅ Passed: {passed}/{total}") - - if passed == total: - print("\n🎉 All examples passed successfully!") - else: - print(f"\n⚠️ {total - passed} example(s) failed") - - return passed == total - - -if __name__ == "__main__": - import sys - - success = run_all_examples() - sys.exit(0 if success else 1) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 76b2984a4b..88989db856 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,11 +7,12 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any -from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node +from core.workflow.node_factory import DifyNodeFactory + from .test_mock_nodes import ( MockAgentNode, MockCodeNode, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py deleted file mode 100644 index aff479104f..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Simple test to verify MockNodeFactory works with iteration nodes. 
-""" - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_factory_registers_iteration_node(): - """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create a MockNodeFactory instance - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Check that iteration node is registered - assert BuiltinNodeTypes.ITERATION in factory._mock_node_types - print("✓ Iteration node is registered in MockNodeFactory") - - # Check that loop node is registered - assert BuiltinNodeTypes.LOOP in factory._mock_node_types - print("✓ Loop node is registered in MockNodeFactory") - - # Check the class types - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode, MockLoopNode - - assert factory._mock_node_types[BuiltinNodeTypes.ITERATION] == MockIterationNode - print("✓ Iteration node maps to MockIterationNode class") - - assert factory._mock_node_types[BuiltinNodeTypes.LOOP] == MockLoopNode - print("✓ Loop node maps to MockLoopNode class") - - -def test_mock_iteration_node_preserves_config(): - """Test that MockIterationNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode - - # Create mock config - mock_config = MockConfigBuilder().with_llm_response("Test response").build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock iteration node - node_config = { - "id": "iter1", - "data": { - "type": "iteration", - "title": "Test", - "iterator_selector": ["start", "items"], - "output_selector": ["node", "text"], - "start_node_id": "node1", - }, - } - - mock_node = MockIterationNode( - id="iter1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the 
mock config is preserved - assert mock_node.mock_config == mock_config - print("✓ MockIterationNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockIterationNode._create_graph_engine != MockIterationNode.__bases__[1]._create_graph_engine - print("✓ MockIterationNode overrides _create_graph_engine method") - - -def test_mock_loop_node_preserves_config(): - """Test that MockLoopNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode - - # Create mock config - mock_config = MockConfigBuilder().with_http_response({"status": 200}).build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock loop node - node_config = { - "id": "loop1", - "data": { - "type": "loop", - "title": "Test", - "loop_count": 3, - "start_node_id": "node1", - "loop_variables": [], - "outputs": {}, - "break_conditions": [], - "logical_operator": "and", - }, - } - - mock_node = MockLoopNode( - id="loop1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("✓ MockLoopNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockLoopNode._create_graph_engine != MockLoopNode.__bases__[1]._create_graph_engine - print("✓ MockLoopNode overrides _create_graph_engine method") - - -if __name__ == "__main__": - test_mock_factory_registers_iteration_node() - test_mock_iteration_node_preserves_config() - test_mock_loop_node_preserves_config() - print("\n✅ All tests passed! 
MockNodeFactory now supports iteration and loop nodes.") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 971b9b2bbf..8b7fbd1b30 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,10 +10,6 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock -from core.model_manager import ModelInstance -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -31,6 +27,11 @@ from graphon.nodes.template_transform import TemplateTransformNode from graphon.nodes.tool import ToolNode from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode + if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py deleted file mode 100644 index 15f6f51398..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ /dev/null @@ -1,670 +0,0 @@ -""" -Test cases for Mock Template Transform and Code nodes. - -This module tests the functionality of MockTemplateTransformNode and MockCodeNode -to ensure they work correctly with the TableTestRunner. 
-""" - -from configs import dify_config -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.nodes.code.limits import CodeNodeLimits -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory -from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode - -DEFAULT_CODE_LIMITS = CodeNodeLimits( - max_string_length=dify_config.CODE_MAX_STRING_LENGTH, - max_number=dify_config.CODE_MAX_NUMBER, - min_number=dify_config.CODE_MIN_NUMBER, - max_precision=dify_config.CODE_MAX_PRECISION, - max_depth=dify_config.CODE_MAX_DEPTH, - max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, - max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, - max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, -) - - -class _NoopCodeExecutor: - def execute(self, *, language: object, code: str, inputs: dict[str, object]) -> dict[str, object]: - _ = (language, code, inputs) - return {} - - def is_execution_error(self, error: Exception) -> bool: - _ = error - return False - - -class TestMockTemplateTransformNode: - """Test cases for MockTemplateTransformNode.""" - - def test_mock_template_transform_node_default_output(self): - """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - # The template "Hello {{ name }}" with no name variable renders as "Hello " - assert result.outputs["output"] == "Hello " - - def test_mock_template_transform_node_custom_output(self): - """Test that MockTemplateTransformNode returns custom configured output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - 
user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder().with_node_output("template_node_1", {"output": "Custom template output"}).build() - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Custom template output" - - def test_mock_template_transform_node_error_simulation(self): - """Test that MockTemplateTransformNode can simulate errors.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with error - mock_config = MockConfigBuilder().with_node_error("template_node_1", "Simulated template error").build() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Simulated template error" - - def test_mock_template_transform_node_with_variables(self): - """Test that MockTemplateTransformNode processes templates with variables.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - from graphon.variables import StringVariable - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - # Add a variable to the pool - variable_pool.add(["test", "name"], StringVariable(name="name", value="World", selector=["test", "name"])) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with a variable - node_config = { - "id": "template_node_1", - "data": { - "type": 
"template-transform", - "title": "Test Template Transform", - "variables": [{"variable": "name", "value_selector": ["test", "name"]}], - "template": "Hello {{ name }}!", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Hello World!" - - -class TestMockCodeNode: - """Test cases for MockCodeNode.""" - - def test_mock_code_node_default_output(self): - """Test that MockCodeNode returns default output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "mocked code execution result" - - def test_mock_code_node_with_output_schema(self): - """Test that MockCodeNode generates outputs based on schema.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with output schema - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "name = 'test'\ncount = 42\nitems = ['a', 'b']", - "outputs": { - "name": {"type": "string"}, - "count": {"type": "number"}, - "items": {"type": "array[string]"}, - }, - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, 
- mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "name" in result.outputs - assert result.outputs["name"] == "mocked_name" - assert "count" in result.outputs - assert result.outputs["count"] == 42 - assert "items" in result.outputs - assert result.outputs["items"] == ["item1", "item2"] - - def test_mock_code_node_custom_output(self): - """Test that MockCodeNode returns custom configured output.""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder() - .with_node_output("code_node_1", {"result": "Custom code result", "status": "success"}) - .build() - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "Custom code result" - assert "status" in result.outputs - assert result.outputs["status"] == "success" - - -class TestMockNodeFactory: - """Test cases for MockNodeFactory with new node types.""" - - def test_code_and_template_nodes_mocked_by_default(self): - """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=[], - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Verify that CODE and TEMPLATE_TRANSFORM ARE mocked by default (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Verify that other third-party service nodes ARE also mocked by default - assert 
-        assert factory.should_mock_node(BuiltinNodeTypes.AGENT)
-
-    def test_factory_creates_mock_template_transform_node(self):
-        """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type."""
-        from graphon.entities import GraphInitParams
-        from graphon.runtime import GraphRuntimeState, VariablePool
-
-        # Create test parameters
-        graph_init_params = GraphInitParams(
-            workflow_id="test_workflow",
-            graph_config={},
-            run_context={
-                DIFY_RUN_CONTEXT_KEY: {
-                    "tenant_id": "test_tenant",
-                    "app_id": "test_app",
-                    "user_id": "test_user",
-                    "user_from": "account",
-                    "invoke_from": "debugger",
-                }
-            },
-            call_depth=0,
-        )
-
-        variable_pool = VariablePool(
-            system_variables=[],
-            user_inputs={},
-        )
-
-        graph_runtime_state = GraphRuntimeState(
-            variable_pool=variable_pool,
-            start_at=0,
-        )
-
-        # Create factory
-        factory = MockNodeFactory(
-            graph_init_params=graph_init_params,
-            graph_runtime_state=graph_runtime_state,
-        )
-
-        # Create node config
-        node_config = {
-            "id": "template_node_1",
-            "data": {
-                "type": "template-transform",
-                "title": "Test Template",
-                "variables": [],
-                "template": "Hello {{ name }}",
-            },
-        }
-
-        # Create node through factory
-        node = factory.create_node(node_config)
-
-        # Verify the correct mock type was created
-        assert isinstance(node, MockTemplateTransformNode)
-        assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM)
-
-    def test_factory_creates_mock_code_node(self):
-        """Test that MockNodeFactory creates MockCodeNode for code type."""
-        from graphon.entities import GraphInitParams
-        from graphon.runtime import GraphRuntimeState, VariablePool
-
-        # Create test parameters
-        graph_init_params = GraphInitParams(
-            workflow_id="test_workflow",
-            graph_config={},
-            run_context={
-                DIFY_RUN_CONTEXT_KEY: {
-                    "tenant_id": "test_tenant",
-                    "app_id": "test_app",
-                    "user_id": "test_user",
-                    "user_from": "account",
-                    "invoke_from": "debugger",
-                }
-            },
-            call_depth=0,
-        )
-
-        variable_pool = VariablePool(
-            system_variables=[],
-            user_inputs={},
-        )
-
-        graph_runtime_state = GraphRuntimeState(
-            variable_pool=variable_pool,
-            start_at=0,
-        )
-
-        # Create factory
-        factory = MockNodeFactory(
-            graph_init_params=graph_init_params,
-            graph_runtime_state=graph_runtime_state,
-        )
-
-        # Create node config
-        node_config = {
-            "id": "code_node_1",
-            "data": {
-                "type": "code",
-                "title": "Test Code",
-                "variables": [],
-                "code_language": "python3",
-                "code": "result = 42",
-                "outputs": {},  # Required field for CodeNodeData
-            },
-        }
-
-        # Create node through factory
-        node = factory.create_node(node_config)
-
-        # Verify the correct mock type was created
-        assert isinstance(node, MockCodeNode)
-        assert factory.should_mock_node(BuiltinNodeTypes.CODE)
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
deleted file mode 100644
index cb5200f8dc..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
+++ /dev/null
@@ -1,231 +0,0 @@
-"""
-Simple test to validate the auto-mock system without external dependencies.
-""" - -import sys - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from graphon.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - print("Testing MockConfigBuilder...") - - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error message" - - print("✓ MockConfigBuilder test passed") - - -def test_mock_config_operations(): - """Test MockConfig operations.""" - print("Testing MockConfig operations...") - - config = MockConfig() - - # Test setting node outputs - config.set_node_outputs("test_node", {"result": "test_value"}) - node_config = config.get_node_config("test_node") - assert node_config is not None - assert node_config.outputs == {"result": "test_value"} - - # Test setting node error - config.set_node_error("error_node", "Test error") - error_config = config.get_node_config("error_node") - assert error_config is not None - assert error_config.error == "Test error" - - # Test default configs by node type - config.set_default_config(BuiltinNodeTypes.LLM, {"temperature": 0.7}) - llm_config = config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config == {"temperature": 0.7} - - print("✓ MockConfig operations test passed") - - -def test_node_mock_config(): - """Test NodeMockConfig.""" - print("Testing NodeMockConfig...") - - # Test with custom handler - def custom_handler(node): - return {"custom": "output"} - - node_config = NodeMockConfig( - node_id="test_node", outputs={"text": "test"}, error=None, delay=0.5, custom_handler=custom_handler - ) - - assert node_config.node_id == "test_node" - assert node_config.outputs == {"text": "test"} - assert node_config.delay == 0.5 - assert node_config.custom_handler is not None - - # Test custom handler - result = node_config.custom_handler(None) - assert result == {"custom": "output"} - - print("✓ NodeMockConfig test passed") - - -def test_mock_factory_detection(): - """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory detection...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - 
"app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - print("✓ MockNodeFactory detection test passed") - - -def test_mock_factory_registration(): - """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from graphon.entities import GraphInitParams - from graphon.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory registration...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Register custom mock (using a dummy class for testing) - class DummyMockNode: - pass - - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, DummyMockNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - print("✓ MockNodeFactory registration test passed") - - -def run_all_tests(): - """Run all tests.""" - print("\n=== Running Auto-Mock System Tests ===\n") - - try: - test_mock_config_builder() - test_mock_config_operations() - test_node_mock_config() - test_mock_factory_detection() - test_mock_factory_registration() - - print("\n=== All tests passed! 
-        return True
-    except AssertionError as e:
-        print(f"\n❌ Test failed: {e}")
-        return False
-    except Exception as e:
-        print(f"\n❌ Unexpected error: {e}")
-        import traceback
-
-        traceback.print_exc()
-        return False
-
-
-if __name__ == "__main__":
-    success = run_all_tests()
-    sys.exit(0 if success else 1)
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
index 37b43bd374..8311a1e847 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
@@ -4,18 +4,10 @@
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from typing import Any, Protocol
-from core.repositories.human_input_repository import (
-    FormCreateParams,
-    HumanInputFormEntity,
-    HumanInputFormRepository,
-)
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import build_system_variables
-from graphon.entities.workflow_start_reason import WorkflowStartReason
+from graphon.entities import WorkflowStartReason
 from graphon.graph import Graph
-from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel
-from graphon.graph_engine.config import GraphEngineConfig
-from graphon.graph_engine.graph_engine import GraphEngine
+from graphon.graph_engine import GraphEngine, GraphEngineConfig
+from graphon.graph_engine.command_channels import InMemoryChannel
 from graphon.graph_events import (
     GraphRunPausedEvent,
     GraphRunStartedEvent,
@@ -31,6 +23,14 @@
 from graphon.nodes.human_input.human_input_node import HumanInputNode
 from graphon.nodes.start.entities import StartNodeData
 from graphon.nodes.start.start_node import StartNode
 from graphon.runtime import GraphRuntimeState, VariablePool
+
+from core.repositories.human_input_repository import (
+    FormCreateParams,
+    HumanInputFormEntity,
+    HumanInputFormRepository,
+)
+from core.workflow.node_runtime import DifyHumanInputNodeRuntime
+from core.workflow.system_variables import build_system_variables
 from libs.datetime_utils import naive_utc_now
 from tests.workflow_test_utils import build_test_graph_init_params
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py
deleted file mode 100644
index 59e54bd39a..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py
+++ /dev/null
@@ -1,336 +0,0 @@
-import time
-from collections.abc import Mapping
-from dataclasses import dataclass
-from datetime import datetime, timedelta
-from typing import Any
-
-from core.repositories.human_input_repository import (
-    FormCreateParams,
-    HumanInputFormEntity,
-    HumanInputFormRepository,
-)
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import build_system_variables
-from graphon.entities.workflow_start_reason import WorkflowStartReason
-from graphon.graph import Graph
-from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel
-from graphon.graph_engine.config import GraphEngineConfig
-from graphon.graph_engine.graph_engine import GraphEngine
-from graphon.graph_events import (
-    GraphRunPausedEvent,
-    GraphRunStartedEvent,
-    NodeRunPauseRequestedEvent,
-    NodeRunStartedEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.message_entities import PromptMessageRole
-from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction
-from graphon.nodes.human_input.enums import HumanInputFormStatus
-from graphon.nodes.human_input.human_input_node import HumanInputNode
-from graphon.nodes.llm.entities import (
-    ContextConfig,
-    LLMNodeChatModelMessage,
-    LLMNodeData,
-    ModelConfig,
-    VisionConfig,
-)
-from graphon.nodes.start.entities import StartNodeData
-from graphon.nodes.start.start_node import StartNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from libs.datetime_utils import naive_utc_now
-from tests.workflow_test_utils import build_test_graph_init_params
-
-from .test_mock_config import MockConfig, NodeMockConfig
-from .test_mock_nodes import MockLLMNode
-
-
-@dataclass
-class StaticForm(HumanInputFormEntity):
-    form_id: str
-    rendered: str
-    is_submitted: bool
-    action_id: str | None = None
-    data: Mapping[str, Any] | None = None
-    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
-    expiration: datetime = naive_utc_now() + timedelta(days=1)
-
-    @property
-    def id(self) -> str:
-        return self.form_id
-
-    @property
-    def submission_token(self) -> str | None:
-        return "token"
-
-    @property
-    def recipients(self) -> list:
-        return []
-
-    @property
-    def rendered_content(self) -> str:
-        return self.rendered
-
-    @property
-    def selected_action_id(self) -> str | None:
-        return self.action_id
-
-    @property
-    def submitted_data(self) -> Mapping[str, Any] | None:
-        return self.data
-
-    @property
-    def submitted(self) -> bool:
-        return self.is_submitted
-
-    @property
-    def status(self) -> HumanInputFormStatus:
-        return self.status_value
-
-    @property
-    def expiration_time(self) -> datetime:
-        return self.expiration
-
-
-class StaticRepo(HumanInputFormRepository):
-    def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
-        self._forms_by_node_id = dict(forms_by_node_id)
-
-    def get_form(self, node_id: str) -> HumanInputFormEntity | None:
-        return self._forms_by_node_id.get(node_id)
-
-    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
-        raise AssertionError("create_form should not be called in resume scenario")
-
-
-class DelayedHumanInputNode(HumanInputNode):
-    def __init__(self, delay_seconds: float, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
-        self._delay_seconds = delay_seconds
-
-    def _run(self):
-        if self._delay_seconds > 0:
-            time.sleep(self._delay_seconds)
-        yield from super()._run()
-
-
-def _build_runtime_state() -> GraphRuntimeState:
-    variable_pool = VariablePool(
-        system_variables=build_system_variables(
-            user_id="user",
-            app_id="app",
-            workflow_id="workflow",
-            workflow_execution_id="exec-1",
-        ),
-        user_inputs={},
-        conversation_variables=[],
-    )
-    return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-
-def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
-    graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    graph_init_params = build_test_graph_init_params(
-        workflow_id="workflow",
-        graph_config=graph_config,
-        tenant_id="tenant",
-        app_id="app",
-        user_id="user",
-        user_from="account",
-        invoke_from="debugger",
-        call_depth=0,
-    )
-
-    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
-    start_node = StartNode(
-        id=start_config["id"],
-        config=start_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-    )
-
-    human_data = HumanInputNodeData(
-        title="Human Input",
-        form_content="Human input required",
-        inputs=[],
-        user_actions=[UserAction(id="approve", title="Approve")],
-    )
-
-    human_a_config = {"id": "human_a", "data": human_data.model_dump()}
-    human_a = HumanInputNode(
-        id=human_a_config["id"],
-        config=human_a_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        form_repository=repo,
-        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
-    )
-
-    human_b_config = {"id": "human_b", "data": human_data.model_dump()}
-    human_b = DelayedHumanInputNode(
-        id=human_b_config["id"],
-        config=human_b_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        form_repository=repo,
-        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
-        delay_seconds=0.2,
-    )
-
-    llm_data = LLMNodeData(
-        title="LLM A",
-        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
-        prompt_template=[
-            LLMNodeChatModelMessage(
-                text="Prompt A",
-                role=PromptMessageRole.USER,
-                edition_type="basic",
-            )
-        ],
-        context=ContextConfig(enabled=False, variable_selector=None),
-        vision=VisionConfig(enabled=False),
-        reasoning_format="tagged",
-        structured_output_enabled=False,
-    )
-    llm_config = {"id": "llm_a", "data": llm_data.model_dump()}
-    llm_a = MockLLMNode(
-        id=llm_config["id"],
-        config=llm_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        mock_config=mock_config,
-    )
-
-    return (
-        Graph.new()
-        .add_root(start_node)
-        .add_node(human_a, from_node_id="start")
-        .add_node(human_b, from_node_id="start")
-        .add_node(llm_a, from_node_id="human_a", source_handle="approve")
-        .build()
-    )
-
-
-def test_parallel_human_input_pause_preserves_node_finished() -> None:
-    runtime_state = _build_runtime_state()
-
-    runtime_state.graph_execution.start()
-    runtime_state.register_paused_node("human_a")
-    runtime_state.register_paused_node("human_b")
-
-    submitted = StaticForm(
-        form_id="form-a",
-        rendered="rendered",
-        is_submitted=True,
-        action_id="approve",
-        data={},
-        status_value=HumanInputFormStatus.SUBMITTED,
-    )
-    pending = StaticForm(
-        form_id="form-b",
-        rendered="rendered",
-        is_submitted=False,
-        action_id=None,
-        data=None,
-        status_value=HumanInputFormStatus.WAITING,
-    )
-    repo = StaticRepo({"human_a": submitted, "human_b": pending})
-
-    mock_config = MockConfig()
-    mock_config.simulate_delays = True
-    mock_config.set_node_config(
-        "llm_a",
-        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
-    )
-
-    graph = _build_graph(runtime_state, repo, mock_config)
-    engine = GraphEngine(
-        workflow_id="workflow",
-        graph=graph,
-        graph_runtime_state=runtime_state,
-        command_channel=InMemoryChannel(),
-        config=GraphEngineConfig(
-            min_workers=2,
-            max_workers=2,
-            scale_up_threshold=1,
-            scale_down_idle_time=30.0,
-        ),
-    )
-
-    events = list(engine.run())
-
-    llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
-    llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
-    human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
-    graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
-    graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events)
-
-    assert graph_started
-    assert graph_paused
-    assert human_b_pause
-    assert llm_started
-    assert llm_succeeded
-
-
-def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None:
-    base_state = _build_runtime_state()
-    base_state.graph_execution.start()
-    base_state.register_paused_node("human_a")
-    base_state.register_paused_node("human_b")
-    snapshot = base_state.dumps()
-
-    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
-
-    submitted = StaticForm(
-        form_id="form-a",
-        rendered="rendered",
-        is_submitted=True,
-        action_id="approve",
-        data={},
-        status_value=HumanInputFormStatus.SUBMITTED,
-    )
-    pending = StaticForm(
-        form_id="form-b",
-        rendered="rendered",
-        is_submitted=False,
-        action_id=None,
-        data=None,
-        status_value=HumanInputFormStatus.WAITING,
-    )
-    repo = StaticRepo({"human_a": submitted, "human_b": pending})
-
-    mock_config = MockConfig()
-    mock_config.simulate_delays = True
-    mock_config.set_node_config(
-        "llm_a",
-        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
-    )
-
-    graph = _build_graph(resumed_state, repo, mock_config)
-    engine = GraphEngine(
-        workflow_id="workflow",
-        graph=graph,
-        graph_runtime_state=resumed_state,
-        command_channel=InMemoryChannel(),
-        config=GraphEngineConfig(
-            min_workers=2,
-            max_workers=2,
-            scale_up_threshold=1,
-            scale_down_idle_time=30.0,
-        ),
-    )
-
-    events = list(engine.run())
-
-    start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent))
-    assert start_event.reason is WorkflowStartReason.RESUMPTION
-
-    llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
-    llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
-    human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
-    graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
-
-    assert graph_paused
-    assert human_b_pause
-    assert llm_started
-    assert llm_succeeded
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py
deleted file mode 100644
index 1a43734462..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py
+++ /dev/null
@@ -1,286 +0,0 @@
-"""
-Test for parallel streaming workflow behavior.
-
-This test validates that:
-- LLM 1 always speaks English
-- LLM 2 always speaks Chinese
-- 2 LLMs run parallel, but LLM 2 will output before LLM 1
-- All chunks should be sent before Answer Node started
-"""
-
-import time
-from unittest.mock import MagicMock, patch
-from uuid import uuid4
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.model_manager import ModelInstance
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
-from core.workflow.system_variables import build_system_variables
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
-from graphon.graph import Graph
-from graphon.graph_engine import GraphEngine, GraphEngineConfig
-from graphon.graph_engine.command_channels import InMemoryChannel
-from graphon.graph_events import (
-    GraphRunSucceededEvent,
-    NodeRunStartedEvent,
-    NodeRunStreamChunkEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.node_events import NodeRunResult, StreamCompletedEvent
-from graphon.nodes.llm.node import LLMNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from tests.workflow_test_utils import build_test_graph_init_params
-
-from .test_table_runner import TableTestRunner
-
-
-def create_llm_generator_with_delay(chunks: list[str], delay: float = 0.1):
-    """Create a generator that simulates LLM streaming output with delay"""
-
-    def llm_generator(self):
-        for i, chunk in enumerate(chunks):
-            time.sleep(delay)  # Simulate network delay
-            yield NodeRunStreamChunkEvent(
-                id=str(uuid4()),
-                node_id=self.id,
-                node_type=self.node_type,
-                selector=[self.id, "text"],
-                chunk=chunk,
-                is_final=i == len(chunks) - 1,
-            )
-
-        # Complete response
-        full_text = "".join(chunks)
-        yield StreamCompletedEvent(
-            node_run_result=NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                outputs={"text": full_text},
-            )
-        )
-
-    return llm_generator
-
-
-def test_parallel_streaming_workflow():
-    """
-    Test parallel streaming workflow to verify:
-    1. All chunks from LLM 2 are output before LLM 1
-    2. At least one chunk from LLM 2 is output before LLM 1 completes (Success)
-    3. At least one chunk from LLM 1 is output before LLM 2 completes (EXPECTED TO FAIL)
-    4. All chunks are output before End begins
-    5. The final output content matches the order defined in the Answer
-
-    Test setup:
-    - LLM 1 outputs English (slower)
-    - LLM 2 outputs Chinese (faster)
-    - Both run in parallel
-
-    This test is expected to FAIL because chunks are currently buffered
-    until after node completion instead of streaming during execution.
- """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("multilingual_parallel_llm_streaming_workflow") - workflow_config = fixture_data.get("workflow", {}) - graph_config = workflow_config.get("graph", {}) - - # Create graph initialization parameters - init_params = build_test_graph_init_params( - workflow_id="test_workflow", - graph_config=graph_config, - tenant_id="test_tenant", - app_id="test_app", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - ) - - # Create variable pool with system variables - system_variables = build_system_variables( - user_id="test_user", - app_id="test_app", - workflow_id=init_params.workflow_id, - files=[], - query="Tell me about yourself", # User query - ) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs={}, - ) - - # Create graph runtime state - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # Create node factory and graph - node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - with patch.object( - DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True - ): - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), - ) - - # Create the graph engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Define LLM outputs - llm1_chunks = ["Hello", ", ", "I", " ", "am", " ", "an", " ", "AI", " ", "assistant", "."] # English (slower) - llm2_chunks = ["你好", ",", "我", "是", "AI", "助手", "。"] # Chinese (faster) - - # Create generators with different delays (LLM 2 is faster) - llm1_generator = create_llm_generator_with_delay(llm1_chunks, delay=0.05) # Slower - llm2_generator = create_llm_generator_with_delay(llm2_chunks, delay=0.01) # Faster - - # Track which LLM node is being called - llm_call_order = [] - generators = { - "1754339718571": llm1_generator, # LLM 1 node ID - "1754339725656": llm2_generator, # LLM 2 node ID - } - - def mock_llm_run(self): - llm_call_order.append(self.id) - generator = generators.get(self.id) - if generator: - yield from generator(self) - else: - raise Exception(f"Unexpected LLM node ID: {self.id}") - - # Execute with mocked LLMs - with patch.object(LLMNode, "_run", new=mock_llm_run): - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Get all streaming chunk events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - - # Get Answer node start event - answer_start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.ANSWER - ] - assert len(answer_start_events) == 1, f"Expected 1 Answer node start event, got {len(answer_start_events)}" - answer_start_event = answer_start_events[0] - - # Find the index of Answer node start - answer_start_index = events.index(answer_start_event) - - # Collect chunk events by node - llm1_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339718571"] - llm2_chunks_events = [e for e in stream_chunk_events 
if e.node_id == "1754339725656"] - - # Verify both LLMs produced chunks - assert len(llm1_chunks_events) == len(llm1_chunks), ( - f"Expected {len(llm1_chunks)} chunks from LLM 1, got {len(llm1_chunks_events)}" - ) - assert len(llm2_chunks_events) == len(llm2_chunks), ( - f"Expected {len(llm2_chunks)} chunks from LLM 2, got {len(llm2_chunks_events)}" - ) - - # 1. Verify chunk ordering based on actual implementation - llm1_chunk_indices = [events.index(e) for e in llm1_chunks_events] - llm2_chunk_indices = [events.index(e) for e in llm2_chunks_events] - - # In the current implementation, chunks may be interleaved or in a specific order - # Update this based on actual behavior observed - if llm1_chunk_indices and llm2_chunk_indices: - # Check the actual ordering - if LLM 2 chunks come first (as seen in debug) - assert max(llm2_chunk_indices) < min(llm1_chunk_indices), ( - f"All LLM 2 chunks should be output before LLM 1 chunks. " - f"LLM 2 chunk indices: {llm2_chunk_indices}, LLM 1 chunk indices: {llm1_chunk_indices}" - ) - - # Get indices of all chunk events - chunk_indices = [events.index(e) for e in stream_chunk_events if e in llm1_chunks_events + llm2_chunks_events] - - # 4. Verify all chunks were sent before Answer node started - assert all(idx < answer_start_index for idx in chunk_indices), ( - "All LLM chunks should be sent before Answer node starts" - ) - - # The test has successfully verified: - # 1. Both LLMs run in parallel (they start at the same time) - # 2. LLM 2 (Chinese) outputs all its chunks before LLM 1 (English) due to faster processing - # 3. All LLM chunks are sent before the Answer node starts - - # Get LLM completion events - llm_completed_events = [ - (i, e) - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ] - - # Check LLM completion order - in the current implementation, LLMs run sequentially - # LLM 1 completes first, then LLM 2 runs and completes - assert len(llm_completed_events) == 2, f"Expected 2 LLM completion events, got {len(llm_completed_events)}" - llm2_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339725656"), None) - llm1_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339718571"), None) - assert llm2_complete_idx is not None, "LLM 2 completion event not found" - assert llm1_complete_idx is not None, "LLM 1 completion event not found" - # In the actual implementation, LLM 1 completes before LLM 2 (sequential execution) - assert llm1_complete_idx < llm2_complete_idx, ( - f"LLM 1 should complete before LLM 2 in sequential execution, but LLM 1 completed at {llm1_complete_idx} " - f"and LLM 2 completed at {llm2_complete_idx}" - ) - - # 2. In sequential execution, LLM 2 chunks appear AFTER LLM 1 completes - if llm2_chunk_indices: - # LLM 1 completes first, then LLM 2 starts streaming - assert min(llm2_chunk_indices) > llm1_complete_idx, ( - f"LLM 2 chunks should appear after LLM 1 completes in sequential execution. " - f"First LLM 2 chunk at index {min(llm2_chunk_indices)}, LLM 1 completed at index {llm1_complete_idx}" - ) - - # 3. 
-    # This is because chunks are buffered and output after both nodes complete
-    if llm1_chunk_indices and llm2_complete_idx:
-        # Check if LLM 1 chunks exist and where they appear relative to LLM 2 completion
-        # In current behavior, LLM 1 chunks typically appear after LLM 2 completes
-        pass  # Skipping this check as the chunk ordering is implementation-dependent
-
-    # CURRENT BEHAVIOR: Chunks are buffered and appear after node completion
-    # In the sequential execution, LLM 1 completes first without streaming,
-    # then LLM 2 streams its chunks
-    assert stream_chunk_events, "Expected streaming events, but got none"
-
-    first_chunk_index = events.index(stream_chunk_events[0])
-    llm_success_indices = [i for i, e in llm_completed_events]
-
-    # Current implementation: LLM 1 completes first, then chunks start appearing
-    # This is the actual behavior we're testing
-    if llm_success_indices:
-        # At least one LLM (LLM 1) completes before any chunks appear
-        assert min(llm_success_indices) < first_chunk_index, (
-            f"In current implementation, LLM 1 completes before chunks start streaming. "
-            f"First chunk at index {first_chunk_index}, LLM 1 completed at index {min(llm_success_indices)}"
-        )
-
-    # 5. Verify final output content matches the order defined in Answer node
-    # According to Answer node configuration: '{{#1754339725656.text#}}{{#1754339718571.text#}}'
-    # This means LLM 2 output should come first, then LLM 1 output
-    answer_complete_events = [
-        e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.ANSWER
-    ]
-    assert len(answer_complete_events) == 1, f"Expected 1 Answer completion event, got {len(answer_complete_events)}"
-
-    answer_outputs = answer_complete_events[0].node_run_result.outputs
-    expected_answer_text = "你好,我是AI助手。Hello, I am an AI assistant."
-
-    if "answer" in answer_outputs:
-        actual_answer_text = answer_outputs["answer"]
-        assert actual_answer_text == expected_answer_text, (
-            f"Answer content should match the order defined in Answer node. "
" - f"Expected: '{expected_answer_text}', Got: '{actual_answer_text}'" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py deleted file mode 100644 index bcf123ee80..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ /dev/null @@ -1,311 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.config import GraphEngineConfig -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.nodes.end.end_node import EndNode -from graphon.nodes.end.entities import EndNodeData -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode -from graphon.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from graphon.nodes.start.entities import StartNodeData -from graphon.nodes.start.start_node import StartNode -from graphon.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def submission_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, form: HumanInputFormEntity) -> None: - self._form = form - - def get_form(self, node_id: str) -> HumanInputFormEntity | None: - if node_id != "human_pause": - return None - return self._form - - def create_form(self, params: 
-        raise AssertionError("create_form should not be called in this test")
-
-
-def _build_runtime_state() -> GraphRuntimeState:
-    variable_pool = VariablePool(
-        system_variables=build_system_variables(
-            user_id="user",
-            app_id="app",
-            workflow_id="workflow",
-            workflow_execution_id="exec-1",
-        ),
-        user_inputs={},
-        conversation_variables=[],
-    )
-    return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-
-def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
-    graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    graph_init_params = build_test_graph_init_params(
-        workflow_id="workflow",
-        graph_config=graph_config,
-        tenant_id="tenant",
-        app_id="app",
-        user_id="user",
-        user_from="account",
-        invoke_from="debugger",
-        call_depth=0,
-    )
-
-    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
-    start_node = StartNode(
-        id=start_config["id"],
-        config=start_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-    )
-
-    llm_a_data = LLMNodeData(
-        title="LLM A",
-        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
-        prompt_template=[
-            LLMNodeChatModelMessage(
-                text="Prompt A",
-                role=PromptMessageRole.USER,
-                edition_type="basic",
-            )
-        ],
-        context=ContextConfig(enabled=False, variable_selector=None),
-        vision=VisionConfig(enabled=False),
-        reasoning_format="tagged",
-        structured_output_enabled=False,
-    )
-    llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()}
-    llm_a = MockLLMNode(
-        id=llm_a_config["id"],
-        config=llm_a_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        mock_config=mock_config,
-    )
-
-    llm_b_data = LLMNodeData(
-        title="LLM B",
-        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
-        prompt_template=[
-            LLMNodeChatModelMessage(
-                text="Prompt B",
-                role=PromptMessageRole.USER,
-                edition_type="basic",
-            )
-        ],
-        context=ContextConfig(enabled=False, variable_selector=None),
-        vision=VisionConfig(enabled=False),
-        reasoning_format="tagged",
-        structured_output_enabled=False,
-    )
-    llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()}
-    llm_b = MockLLMNode(
-        id=llm_b_config["id"],
-        config=llm_b_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        mock_config=mock_config,
-    )
-
-    human_data = HumanInputNodeData(
-        title="Human Input",
-        form_content="Pause here",
-        inputs=[],
-        user_actions=[UserAction(id="approve", title="Approve")],
-    )
-    human_config = {"id": "human_pause", "data": human_data.model_dump()}
-    human_node = HumanInputNode(
-        id=human_config["id"],
-        config=human_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-        form_repository=repo,
-        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
-    )
-
-    end_human_data = EndNodeData(title="End Human", outputs=[], desc=None)
-    end_human_config = {"id": "end_human", "data": end_human_data.model_dump()}
-    end_human = EndNode(
-        id=end_human_config["id"],
-        config=end_human_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=runtime_state,
-    )
-
-    return (
-        Graph.new()
-        .add_root(start_node)
-        .add_node(llm_a, from_node_id="start")
-        .add_node(human_node, from_node_id="start")
-        .add_node(llm_b, from_node_id="llm_a")
-        .add_node(end_human, from_node_id="human_pause", source_handle="approve")
from_node_id="human_pause", source_handle="approve") - .build() - ) - - -def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def test_pause_defers_ready_nodes_until_resume() -> None: - runtime_state = _build_runtime_state() - - paused_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=False, - status_value=HumanInputFormStatus.WAITING, - ) - pause_repo = StaticRepo(paused_form) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - mock_config.set_node_config( - "llm_b", - NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), - ) - - graph = _build_graph(runtime_state, pause_repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - paused_events = list(engine.run()) - - assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) - assert _get_node_started_event(paused_events, "llm_b") is None - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - resume_repo = StaticRepo(submitted_form) - - resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) - resumed_engine = GraphEngine( - workflow_id="workflow", - graph=resumed_graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - resumed_events = list(resumed_engine.run()) - - start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_b_started = _get_node_started_event(resumed_events, "llm_b") - assert llm_b_started is not None - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py deleted file mode 100644 index 79d3d5bcfe..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ /dev/null @@ -1,219 +0,0 @@ -import datetime -import time -from typing import Any -from unittest.mock import MagicMock - -from core.repositories.human_input_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables -from graphon.entities.workflow_start_reason import WorkflowStartReason -from graphon.graph import Graph -from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from graphon.graph_engine.graph_engine import GraphEngine -from graphon.graph_events import ( - 
-    GraphEngineEvent,
-    GraphRunPausedEvent,
-    GraphRunSucceededEvent,
-    NodeRunStartedEvent,
-    NodeRunSucceededEvent,
-)
-from graphon.graph_events.graph import GraphRunStartedEvent
-from graphon.nodes.base.entities import OutputVariableEntity
-from graphon.nodes.end.end_node import EndNode
-from graphon.nodes.end.entities import EndNodeData
-from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction
-from graphon.nodes.human_input.human_input_node import HumanInputNode
-from graphon.nodes.start.entities import StartNodeData
-from graphon.nodes.start.start_node import StartNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from libs.datetime_utils import naive_utc_now
-from tests.workflow_test_utils import build_test_graph_init_params
-
-
-def _build_runtime_state() -> GraphRuntimeState:
-    variable_pool = VariablePool(
-        system_variables=build_system_variables(
-            user_id="user",
-            app_id="app",
-            workflow_id="workflow",
-            workflow_execution_id="test-execution-id",
-        ),
-        user_inputs={},
-        conversation_variables=[],
-    )
-    return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-
-def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
-    repo = MagicMock(spec=HumanInputFormRepository)
-    form_entity = MagicMock(spec=HumanInputFormEntity)
-    form_entity.id = "test-form-id"
-    form_entity.submission_token = "test-form-token"
-    form_entity.recipients = []
-    form_entity.rendered_content = "rendered"
-    form_entity.submitted = True
-    form_entity.selected_action_id = action_id
-    form_entity.submitted_data = {}
-    form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
-    repo.get_form.return_value = form_entity
-    return repo
-
-
-def _mock_form_repository_without_submission() -> HumanInputFormRepository:
-    repo = MagicMock(spec=HumanInputFormRepository)
-    form_entity = MagicMock(spec=HumanInputFormEntity)
-    form_entity.id = "test-form-id"
-    form_entity.submission_token = "test-form-token"
-    form_entity.recipients = []
-    form_entity.rendered_content = "rendered"
-    form_entity.submitted = False
-    repo.create_form.return_value = form_entity
-    repo.get_form.return_value = None
-    return repo
-
-
-def _build_human_input_graph(
-    runtime_state: GraphRuntimeState,
-    form_repository: HumanInputFormRepository,
-) -> Graph:
-    graph_config: dict[str, object] = {"nodes": [], "edges": []}
-    params = build_test_graph_init_params(
-        workflow_id="workflow",
-        graph_config=graph_config,
-        tenant_id="tenant",
-        app_id="app",
-        user_id="user",
-        user_from="account",
-        invoke_from="service-api",
-        call_depth=0,
-    )
-
-    start_data = StartNodeData(title="start", variables=[])
-    start_node = StartNode(
-        id="start",
-        config={"id": "start", "data": start_data.model_dump()},
-        graph_init_params=params,
-        graph_runtime_state=runtime_state,
-    )
-
-    human_data = HumanInputNodeData(
-        title="human",
-        form_content="Awaiting human input",
-        inputs=[],
-        user_actions=[
-            UserAction(id="continue", title="Continue"),
-        ],
-    )
-    human_node = HumanInputNode(
-        id="human",
-        config={"id": "human", "data": human_data.model_dump()},
-        graph_init_params=params,
-        graph_runtime_state=runtime_state,
-        form_repository=form_repository,
-        runtime=DifyHumanInputNodeRuntime(params.run_context),
-    )
-
-    end_data = EndNodeData(
-        title="end",
-        outputs=[
-            OutputVariableEntity(variable="result", value_selector=["human", "action_id"]),
-        ],
-        desc=None,
-    )
-    end_node = EndNode(
-        id="end",
-        config={"id": "end", "data": end_data.model_dump()},
-        graph_init_params=params,
-        graph_runtime_state=runtime_state,
-    )
-
-    return (
-        Graph.new()
-        .add_root(start_node)
-        .add_node(human_node)
-        .add_node(end_node, from_node_id="human", source_handle="continue")
-        .build()
-    )
-
-
-def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
-    engine = GraphEngine(
-        workflow_id="workflow",
-        graph=graph,
-        graph_runtime_state=runtime_state,
-        command_channel=InMemoryChannel(),
-    )
-    return list(engine.run())
-
-
-def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
-    return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
-
-
-def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None:
-    for event in events:
-        if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
-            return event
-    return None
-
-
-def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
-    segment = variable_pool.get(selector)
-    assert segment is not None
-    return getattr(segment, "value", segment)
-
-
-def test_engine_resume_restores_state_and_completion():
-    # Baseline run without pausing
-    baseline_state = _build_runtime_state()
-    baseline_repo = _mock_form_repository_with_submission(action_id="continue")
-    baseline_graph = _build_human_input_graph(baseline_state, baseline_repo)
-    baseline_events = _run_graph(baseline_graph, baseline_state)
-    assert baseline_events
-    first_paused_event = baseline_events[0]
-    assert isinstance(first_paused_event, GraphRunStartedEvent)
-    assert first_paused_event.reason is WorkflowStartReason.INITIAL
-    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
-    baseline_success_nodes = _node_successes(baseline_events)
-
-    # Run with pause
-    paused_state = _build_runtime_state()
-    pause_repo = _mock_form_repository_without_submission()
-    paused_graph = _build_human_input_graph(paused_state, pause_repo)
-    paused_events = _run_graph(paused_graph, paused_state)
-    assert paused_events
-    first_paused_event = paused_events[0]
-    assert isinstance(first_paused_event, GraphRunStartedEvent)
-    assert first_paused_event.reason is WorkflowStartReason.INITIAL
-    assert isinstance(paused_events[-1], GraphRunPausedEvent)
-    snapshot = paused_state.dumps()
-
-    # Resume from snapshot
-    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
-    resume_repo = _mock_form_repository_with_submission(action_id="continue")
-    resumed_graph = _build_human_input_graph(resumed_state, resume_repo)
-    resumed_events = _run_graph(resumed_graph, resumed_state)
-    assert resumed_events
-    first_resumed_event = resumed_events[0]
-    assert isinstance(first_resumed_event, GraphRunStartedEvent)
-    assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION
-    assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
-
-    combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
-    assert combined_success_nodes == baseline_success_nodes
-
-    paused_human_started = _node_start_event(paused_events, "human")
-    resumed_human_started = _node_start_event(resumed_events, "human")
-    assert paused_human_started is not None
-    assert resumed_human_started is not None
-    assert paused_human_started.id == resumed_human_started.id
-
-    assert baseline_state.outputs == resumed_state.outputs
-    assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
-        resumed_state.variable_pool, ("human", "__action_id")
-    )
-    assert baseline_state.graph_execution.completed
-    assert resumed_state.graph_execution.completed
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py
deleted file mode 100644
index 146b728dc2..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py
+++ /dev/null
@@ -1,268 +0,0 @@
-"""
-Unit tests for Redis-based stop functionality in GraphEngine.
-
-Tests the integration of Redis command channel for stopping workflows
-without user permission checks.
-"""
-
-import json
-from unittest.mock import MagicMock, Mock, patch
-
-import pytest
-import redis
-
-from core.app.apps.base_app_queue_manager import AppQueueManager
-from graphon.graph_engine.command_channels.redis_channel import RedisChannel
-from graphon.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
-from graphon.graph_engine.manager import GraphEngineManager
-
-
-class TestRedisStopIntegration:
-    """Test suite for Redis-based workflow stop functionality."""
-
-    def test_graph_engine_manager_sends_abort_command(self):
-        """Test that GraphEngineManager correctly sends abort command through Redis."""
-        # Setup
-        task_id = "test-task-123"
-        expected_channel_key = f"workflow:{task_id}:commands"
-
-        # Mock redis client
-        mock_redis = MagicMock()
-        mock_pipeline = MagicMock()
-        mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
-        mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
-
-        manager = GraphEngineManager(mock_redis)
-
-        # Execute
-        manager.send_stop_command(task_id, reason="Test stop")
-
-        # Verify
-        mock_redis.pipeline.assert_called_once()
-
-        # Check that rpush was called with correct arguments
-        calls = mock_pipeline.rpush.call_args_list
-        assert len(calls) == 1
-
-        # Verify the channel key
-        assert calls[0][0][0] == expected_channel_key
-
-        # Verify the command data
-        command_json = calls[0][0][1]
-        command_data = json.loads(command_json)
-        assert command_data["command_type"] == CommandType.ABORT
-        assert command_data["reason"] == "Test stop"
-
-    def test_graph_engine_manager_sends_pause_command(self):
-        """Test that GraphEngineManager correctly sends pause command through Redis."""
-        task_id = "test-task-pause-123"
-        expected_channel_key = f"workflow:{task_id}:commands"
-
-        mock_redis = MagicMock()
-        mock_pipeline = MagicMock()
-        mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
-        mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
-
-        manager = GraphEngineManager(mock_redis)
-        manager.send_pause_command(task_id, reason="Awaiting resources")
-
-        mock_redis.pipeline.assert_called_once()
-        calls = mock_pipeline.rpush.call_args_list
-        assert len(calls) == 1
-        assert calls[0][0][0] == expected_channel_key
-
-        command_json = calls[0][0][1]
-        command_data = json.loads(command_json)
-        assert command_data["command_type"] == CommandType.PAUSE.value
-        assert command_data["reason"] == "Awaiting resources"
-
-    def test_graph_engine_manager_handles_redis_failure_gracefully(self):
-        """Test that GraphEngineManager handles Redis failures without raising exceptions."""
-        task_id = "test-task-456"
-
-        # Mock redis client to raise exception
-        mock_redis = MagicMock()
-        mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed")
-        manager = GraphEngineManager(mock_redis)
-
-        # Should not raise exception
-        try:
-            manager.send_stop_command(task_id)
-        except Exception as e:
-            pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly")
-
-    def test_app_queue_manager_no_user_check(self):
-        """Test that AppQueueManager.set_stop_flag_no_user_check works without user validation."""
-        task_id = "test-task-789"
-        expected_cache_key = f"generate_task_stopped:{task_id}"
-
-        # Mock redis client
-        mock_redis = MagicMock()
-
-        with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
-            # Execute
-            AppQueueManager.set_stop_flag_no_user_check(task_id)
-
-            # Verify
-            mock_redis.setex.assert_called_once_with(expected_cache_key, 600, 1)
-
-    def test_app_queue_manager_no_user_check_with_empty_task_id(self):
-        """Test that AppQueueManager.set_stop_flag_no_user_check handles empty task_id."""
-        # Mock redis client
-        mock_redis = MagicMock()
-
-        with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
-            # Execute with empty task_id
-            AppQueueManager.set_stop_flag_no_user_check("")
-
-            # Verify redis was not called
-            mock_redis.setex.assert_not_called()
-
-    def test_redis_channel_send_abort_command(self):
-        """Test RedisChannel correctly serializes and sends AbortCommand."""
-        # Setup
-        mock_redis = MagicMock()
-        mock_pipeline = MagicMock()
-        mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
-        mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
-
-        channel_key = "workflow:test:commands"
-        channel = RedisChannel(mock_redis, channel_key)
-
-        # Create commands
-        abort_command = AbortCommand(reason="User requested stop")
-        pause_command = PauseCommand(reason="User requested pause")
-
-        # Execute
-        channel.send_command(abort_command)
-        channel.send_command(pause_command)
-
-        # Verify
-        mock_redis.pipeline.assert_called()
-
-        # Check rpush was called
-        calls = mock_pipeline.rpush.call_args_list
-        assert len(calls) == 2
-        assert calls[0][0][0] == channel_key
-        assert calls[1][0][0] == channel_key
-
-        # Verify serialized commands
-        abort_command_json = calls[0][0][1]
-        abort_command_data = json.loads(abort_command_json)
-        assert abort_command_data["command_type"] == CommandType.ABORT.value
-        assert abort_command_data["reason"] == "User requested stop"
-
-        pause_command_json = calls[1][0][1]
-        pause_command_data = json.loads(pause_command_json)
-        assert pause_command_data["command_type"] == CommandType.PAUSE.value
-        assert pause_command_data["reason"] == "User requested pause"
-
-        # Check expire was set for each
-        assert mock_pipeline.expire.call_count == 2
-        mock_pipeline.expire.assert_any_call(channel_key, 3600)
-
-    def test_redis_channel_fetch_commands(self):
-        """Test RedisChannel correctly fetches and deserializes commands."""
-        # Setup
-        mock_redis = MagicMock()
-        pending_pipe = MagicMock()
-        fetch_pipe = MagicMock()
-        pending_context = MagicMock()
-        fetch_context = MagicMock()
-        pending_context.__enter__.return_value = pending_pipe
-        pending_context.__exit__.return_value = None
-        fetch_context.__enter__.return_value = fetch_pipe
-        fetch_context.__exit__.return_value = None
-        mock_redis.pipeline.side_effect = [pending_context, fetch_context]
-
-        # Mock command data
-        abort_command_json = json.dumps(
-            {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
-        )
-        pause_command_json = json.dumps(
-            {"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None}
-        )
-
-        # Mock pipeline execute to return commands
-        pending_pipe.execute.return_value = [b"1", 1]
-        fetch_pipe.execute.return_value = [
[abort_command_json.encode(), pause_command_json.encode()], # lrange result - True, # delete result - ] - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Execute - commands = channel.fetch_commands() - - # Verify - assert len(commands) == 2 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - assert commands[0].reason == "Test abort" - assert isinstance(commands[1], PauseCommand) - assert commands[1].command_type == CommandType.PAUSE - assert commands[1].reason == "Pause requested" - - # Verify Redis operations - pending_pipe.get.assert_called_once_with(f"{channel_key}:pending") - pending_pipe.delete.assert_called_once_with(f"{channel_key}:pending") - fetch_pipe.lrange.assert_called_once_with(channel_key, 0, -1) - fetch_pipe.delete.assert_called_once_with(channel_key) - assert mock_redis.pipeline.call_count == 2 - - def test_redis_channel_fetch_commands_handles_invalid_json(self): - """Test RedisChannel gracefully handles invalid JSON in commands.""" - # Setup - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mock invalid command data - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [ - [b"invalid json", b'{"command_type": "invalid_type"}'], # lrange result - True, # delete result - ] - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Execute - commands = channel.fetch_commands() - - # Should return empty list due to invalid commands - assert len(commands) == 0 - - def test_dual_stop_mechanism_compatibility(self): - """Test that both stop mechanisms can work together.""" - task_id = "test-task-dual" - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute both stop mechanisms - AppQueueManager.set_stop_flag_no_user_check(task_id) - GraphEngineManager(mock_redis).send_stop_command(task_id) - - # Verify legacy stop flag was set - expected_stop_flag_key = f"generate_task_stopped:{task_id}" - mock_redis.setex.assert_called_once_with(expected_stop_flag_key, 600, 1) - - # Verify command was sent through Redis channel - mock_redis.pipeline.assert_called() - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - assert calls[0][0][0] == f"workflow:{task_id}:commands" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py deleted file mode 100644 index 62ca7a630e..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Unit tests for response session creation.""" - -from __future__ import annotations - -import pytest - -from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType -from graphon.graph_engine.response_coordinator.session import ResponseSession -from 
graphon.nodes.base.template import Template, TextSegment - - -class DummyResponseNode: - """Minimal response-capable node for session tests.""" - - def __init__(self, *, node_id: str, node_type: NodeType, template: Template) -> None: - self.id = node_id - self.node_type = node_type - self.execution_type = NodeExecutionType.RESPONSE - self.state = NodeState.UNKNOWN - self._template = template - - def get_streaming_template(self) -> Template: - return self._template - - -class DummyNodeWithoutStreamingTemplate: - """Minimal node that violates the response-session contract.""" - - def __init__(self, *, node_id: str, node_type: NodeType) -> None: - self.id = node_id - self.node_type = node_type - self.execution_type = NodeExecutionType.RESPONSE - self.state = NodeState.UNKNOWN - - -def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None: - """Session creation depends on the streaming-template contract rather than node type.""" - node = DummyResponseNode( - node_id="llm-node", - node_type=BuiltinNodeTypes.LLM, - template=Template(segments=[TextSegment(text="hello")]), - ) - - session = ResponseSession.from_node(node) - - assert session.node_id == "llm-node" - assert session.template.segments == [TextSegment(text="hello")] - - -def test_response_session_from_node_requires_streaming_template_method() -> None: - """Allowed node types still need to implement the streaming-template contract.""" - node = DummyNodeWithoutStreamingTemplate(node_id="answer-node", node_type=BuiltinNodeTypes.ANSWER) - - with pytest.raises(TypeError, match="get_streaming_template"): - ResponseSession.from_node(node) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py deleted file mode 100644 index a359a5fef9..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ /dev/null @@ -1,79 +0,0 @@ -from graphon.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunVariableUpdatedEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_streaming_conversation_variables(): - fixture_name = "test_streaming_conversation_variables" - - # The test expects the workflow to output the input query - # Since the workflow assigns sys.query to conversation variable "str" and then answers with it - input_query = "Hello, this is my test query" - - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=False, # Don't use auto mock since we want to test actual variable assignment - mock_config=mock_config, - query=input_query, # Pass query as the sys.query value - inputs={}, # No additional inputs needed - expected_outputs={"answer": input_query}, # Expecting the input query to be output - expected_event_sequence=[ - GraphRunStartedEvent, - # START node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Variable Assigner node - NodeRunStartedEvent, - NodeRunVariableUpdatedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - # ANSWER node - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - - -def 
test_streaming_conversation_variables_v1_overwrite_waits_for_assignment():
-    fixture_name = "test_streaming_conversation_variables_v1_overwrite"
-    input_query = "overwrite-value"
-
-    case = WorkflowTestCase(
-        fixture_path=fixture_name,
-        use_auto_mock=False,
-        mock_config=MockConfigBuilder().build(),
-        query=input_query,
-        inputs={},
-        expected_outputs={"answer": f"Current Value Of `conv_var` is:{input_query}"},
-    )
-
-    runner = TableTestRunner()
-    result = runner.run_test_case(case)
-    assert result.success, f"Test failed: {result.error}"
-
-    events = result.events
-    conv_var_chunk_events = [
-        event
-        for event in events
-        if isinstance(event, NodeRunStreamChunkEvent) and tuple(event.selector) == ("conversation", "conv_var")
-    ]
-
-    assert conv_var_chunk_events, "Expected conversation variable chunk events to be emitted"
-    assert all(event.chunk == input_query for event in conv_var_chunk_events), (
-        "Expected streamed conversation variable value to match the input query"
-    )
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
index 81d68ba2aa..b11f957677 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
@@ -19,12 +19,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any
 
-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
-from core.tools.utils.yaml_utils import _load_yaml_file
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
-from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
-from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
-from graphon.entities.graph_init_params import GraphInitParams
+from graphon.entities import GraphInitParams
 from graphon.graph import Graph
 from graphon.graph_engine import GraphEngine, GraphEngineConfig
 from graphon.graph_engine.command_channels import InMemoryChannel
@@ -44,6 +39,12 @@ from graphon.variables import (
     StringVariable,
 )
 
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
+from core.tools.utils.yaml_utils import _load_yaml_file
+from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
+from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
+from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
+
 from .test_mock_config import MockConfig
 from .test_mock_factory import MockNodeFactory
 
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py
deleted file mode 100644
index a7309f64de..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""Validate conversation variable updates inside an iteration workflow.
-
-This test uses the ``update-conversation-variable-in-iteration`` fixture, which
-routes ``sys.query`` into the conversation variable ``answer`` from within an
-iteration container. The workflow should surface that updated conversation
-variable in the final answer output.
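As an aside for readers tracing the fixture's behavior, here is a minimal sketch of the conversation-variable plumbing it exercises. It assumes only the VariablePool and StringVariable API that other tests in this patch use; the ("conversation", <name>) selector shape is taken from the deleted test_variable_update_events.py further below.

from uuid import uuid4

from graphon.runtime import VariablePool
from graphon.variables import StringVariable

# Conversation variables live in the pool under a ("conversation", <name>)
# selector; the iteration fixture writes sys.query there and the answer
# node reads it back out.
pool = VariablePool()
pool.add(
    ["conversation", "answer"],
    StringVariable(id=str(uuid4()), name="answer", value="ensure conversation variable syncs"),
)
segment = pool.get(["conversation", "answer"])
assert segment is not None
assert segment.value == "ensure conversation variable syncs"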
-
-Code nodes in the fixture are mocked because their concrete outputs are not
-relevant to verifying variable propagation semantics.
-"""
-
-from .test_mock_config import MockConfigBuilder
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-
-def test_update_conversation_variable_in_iteration():
-    fixture_name = "update-conversation-variable-in-iteration"
-    user_query = "ensure conversation variable syncs"
-
-    mock_config = (
-        MockConfigBuilder()
-        .with_node_output("1759032363865", {"result": [1]})
-        .with_node_output("1759032476318", {"result": ""})
-        .build()
-    )
-
-    case = WorkflowTestCase(
-        fixture_path=fixture_name,
-        use_auto_mock=True,
-        mock_config=mock_config,
-        query=user_query,
-        expected_outputs={"answer": user_query},
-        description="Conversation variable updated within iteration should flow to answer output.",
-    )
-
-    runner = TableTestRunner()
-    result = runner.run_test_case(case)
-
-    assert result.success, f"Workflow execution failed: {result.error}"
-    assert result.actual_outputs is not None
-    assert result.actual_outputs.get("answer") == user_query
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py
deleted file mode 100644
index 2ad41037a9..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from unittest.mock import patch
-
-import pytest
-
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.node_events import NodeRunResult
-from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
-
-from .test_table_runner import TableTestRunner, WorkflowTestCase
-
-
-class TestVariableAggregator:
-    """Test cases for the variable aggregator workflow."""
-
-    @pytest.mark.parametrize(
-        ("switch1", "switch2", "expected_group1", "expected_group2", "description"),
-        [
-            (0, 0, "switch 1 off", "switch 2 off", "Both switches off"),
-            (0, 1, "switch 1 off", "switch 2 on", "Switch1 off, Switch2 on"),
-            (1, 0, "switch 1 on", "switch 2 off", "Switch1 on, Switch2 off"),
-            (1, 1, "switch 1 on", "switch 2 on", "Both switches on"),
-        ],
-    )
-    def test_variable_aggregator_combinations(
-        self,
-        switch1: int,
-        switch2: int,
-        expected_group1: str,
-        expected_group2: str,
-        description: str,
-    ) -> None:
-        """Test all four combinations of switch1 and switch2."""
-
-        def mock_template_transform_run(self):
-            """Mock the TemplateTransformNode._run() method to return results based on node title."""
-            title = self._node_data.title
-            return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title})
-
-        with patch.object(
-            TemplateTransformNode,
-            "_run",
-            mock_template_transform_run,
-        ):
-            runner = TableTestRunner()
-
-            test_case = WorkflowTestCase(
-                fixture_path="dual_switch_variable_aggregator_workflow",
-                inputs={"switch1": switch1, "switch2": switch2},
-                expected_outputs={"group1": expected_group1, "group2": expected_group2},
-                description=description,
-            )
-
-            result = runner.run_test_case(test_case)
-
-            assert result.success, f"Test failed: {result.error}"
-            assert result.actual_outputs == test_case.expected_outputs, (
-                f"Output mismatch: expected {test_case.expected_outputs}, got {result.actual_outputs}"
-            )
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py
deleted file mode 100644
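The file removed next built its assertions on a capturing GraphEngineLayer. For orientation, a hedged no-op sketch of the layer contract it relied on: the import path and hook signatures come from the deleted code, everything else is illustrative.

from graphon.graph_engine.layers.base import GraphEngineLayer

class NoOpLayer(GraphEngineLayer):
    """A do-nothing layer showing the hooks the engine invokes around a run."""

    def on_graph_start(self) -> None:
        pass  # invoked once before execution begins

    def on_event(self, event) -> None:
        pass  # invoked for every graph event the engine emits

    def on_graph_end(self, error: Exception | None) -> None:
        pass  # invoked once after the run; error is set when the run failed

A layer is attached with engine.layer(NoOpLayer()), just as the deleted test attaches its CaptureVariableUpdateLayer.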
index 60cab77c0a..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py +++ /dev/null @@ -1,129 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import InMemoryChannel -from graphon.graph_engine.layers.base import GraphEngineLayer -from graphon.graph_events import NodeRunVariableUpdatedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringVariable - -DEFAULT_NODE_ID = "node_id" - - -class CaptureVariableUpdateLayer(GraphEngineLayer): - def __init__(self) -> None: - super().__init__() - self.events: list[NodeRunVariableUpdatedEvent] = [] - self.observed_values: list[object | None] = [] - - def on_graph_start(self) -> None: - pass - - def on_event(self, event) -> None: - if not isinstance(event, NodeRunVariableUpdatedEvent): - return - - current_value = self.graph_runtime_state.variable_pool.get(event.variable.selector) - self.events.append(event) - self.observed_values.append(None if current_value is None else current_value.value) - - def on_graph_end(self, error: Exception | None) -> None: - pass - - -def test_graph_engine_applies_variable_updates_before_notifying_layers(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "over-write", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id=str(uuid.uuid4())), - conversation_variables=[ - StringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value="the first value", - ) - ], - ), - ) - variable_pool.add( - [DEFAULT_NODE_ID, "test_string_variable"], - StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ), - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - engine = GraphEngine( - workflow_id="workflow-id", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - capture_layer = CaptureVariableUpdateLayer() - engine.layer(capture_layer) - - events = 
list(engine.run()) - - update_events = [event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)] - assert len(update_events) == 1 - assert update_events[0].variable.value == "the second value" - - current_value = graph_runtime_state.variable_pool.get(["conversation", "test_conversation_variable"]) - assert current_value is not None - assert current_value.value == "the second value" - - assert len(capture_layer.events) == 1 - assert capture_layer.observed_values == ["the second value"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py deleted file mode 100644 index 85132674b8..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py +++ /dev/null @@ -1,148 +0,0 @@ -import queue -from collections.abc import Generator -from datetime import UTC, datetime, timedelta -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from graphon.graph_engine.ready_queue import InMemoryReadyQueue -from graphon.graph_engine.worker import Worker -from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent - - -def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: - fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - mock_datetime = mocker.patch("graphon.graph_engine.worker.datetime") - mock_datetime.now.return_value = fixed_time.replace(tzinfo=UTC) - - worker = Worker( - ready_queue=InMemoryReadyQueue(), - event_queue=queue.Queue(), - graph=MagicMock(), - layers=[], - ) - node = SimpleNamespace( - execution_id="exec-1", - id="node-1", - node_type=BuiltinNodeTypes.LLM, - ) - - event = worker._build_fallback_failure_event(node, RuntimeError("boom")) - - assert event.start_at == fixed_time - assert event.finished_at == fixed_time - assert event.error == "boom" - assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert event.node_run_result.error == "boom" - assert event.node_run_result.error_type == "RuntimeError" - - -def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: - start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - failure_time = start_at + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeNode: - execution_id = "exec-1" - id = "node-1" - node_type = BuiltinNodeTypes.LLM - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="LLM", - start_at=start_at, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"node-1": FakeNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["node-1"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 1: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("graphon.graph_engine.worker.datetime") as mock_datetime: - mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert 
fallback_event.start_at == start_at - assert fallback_event.finished_at == failure_time - assert fallback_event.error == "queue boom" - assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - - -def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None: - parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - child_start = parent_start + timedelta(seconds=3) - failure_time = parent_start + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeIterationNode: - execution_id = "iteration-exec" - id = "iteration-node" - node_type = BuiltinNodeTypes.ITERATION - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="Iteration", - start_at=parent_start, - ) - yield NodeRunStartedEvent( - id="child-exec", - node_id="child-node", - node_type=BuiltinNodeTypes.LLM, - node_title="LLM", - start_at=child_start, - in_iteration_id=self.id, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["iteration-node"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 2: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("graphon.graph_engine.worker.datetime") as mock_datetime: - mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert fallback_event.start_at == parent_start - assert fallback_event.finished_at == failure_time diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py index 1f4509af9a..cbc920705c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -1,8 +1,9 @@ from unittest.mock import patch +from graphon.enums import BuiltinNodeTypes + from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer -from graphon.enums import BuiltinNodeTypes def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py index c86de7f6e6..59dd763b59 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -1,9 +1,10 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from graphon.model_runtime.entities.model_entities import ModelType +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport + def test_fetch_model_reuses_single_model_assembly(): provider_configuration = SimpleNamespace( diff --git 
a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 9c0ad25b58..7195471eb6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,14 +2,15 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from extensions.ext_database import db -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.answer.answer_node import AnswerNode -from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index ec4cef1955..343bcd3919 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,10 +1,10 @@ import pytest - -from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node +from core.workflow.node_factory import get_node_type_classes_mapping + # Ensures that all production node classes are imported and registered. 
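A hedged sketch of what that registration gives test modules: the helper name comes from this patch, while the non-empty assertion is illustrative rather than a documented contract.

from core.workflow.node_factory import get_node_type_classes_mapping

# Importing the factory helper pulls in and registers the production node
# classes; the returned mapping resolves node implementations by node type.
registry = get_node_type_classes_mapping()
assert registry, "expected production node classes to be registered"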
_ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index ef0df55995..b9371a34f4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,7 +1,6 @@ import types from collections.abc import Mapping -from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node @@ -14,6 +13,8 @@ from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) +from core.workflow.node_factory import get_node_type_classes_mapping + def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index ce0c9b79c6..d155124c50 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,4 +1,3 @@ -from configs import dify_config from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.entities import CodeLanguage, CodeNodeData from graphon.nodes.code.exc import ( @@ -9,6 +8,8 @@ from graphon.nodes.code.exc import ( from graphon.nodes.code.limits import CodeNodeLimits from graphon.variables.types import SegmentType +from configs import dify_config + CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py deleted file mode 100644 index 20fe2c1a74..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ /dev/null @@ -1,352 +0,0 @@ -import pytest -from pydantic import ValidationError - -from graphon.nodes.code.entities import CodeLanguage, CodeNodeData -from graphon.variables.types import SegmentType - - -class TestCodeNodeDataOutput: - """Test suite for CodeNodeData.Output model.""" - - def test_output_with_string_type(self): - """Test Output with STRING type.""" - output = CodeNodeData.Output(type=SegmentType.STRING) - - assert output.type == SegmentType.STRING - assert output.children is None - - def test_output_with_number_type(self): - """Test Output with NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.NUMBER) - - assert output.type == SegmentType.NUMBER - assert output.children is None - - def test_output_with_boolean_type(self): - """Test Output with BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.BOOLEAN) - - assert output.type == SegmentType.BOOLEAN - - def test_output_with_object_type(self): - """Test Output with OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.OBJECT) - - assert output.type == SegmentType.OBJECT - - def test_output_with_array_string_type(self): - """Test Output with ARRAY_STRING type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING) - - assert output.type == SegmentType.ARRAY_STRING - - def test_output_with_array_number_type(self): - """Test Output with ARRAY_NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER) - - assert 
output.type == SegmentType.ARRAY_NUMBER - - def test_output_with_array_object_type(self): - """Test Output with ARRAY_OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT) - - assert output.type == SegmentType.ARRAY_OBJECT - - def test_output_with_array_boolean_type(self): - """Test Output with ARRAY_BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN) - - assert output.type == SegmentType.ARRAY_BOOLEAN - - def test_output_with_nested_children(self): - """Test Output with nested children for OBJECT type.""" - child_output = CodeNodeData.Output(type=SegmentType.STRING) - parent_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"name": child_output}, - ) - - assert parent_output.type == SegmentType.OBJECT - assert parent_output.children is not None - assert "name" in parent_output.children - assert parent_output.children["name"].type == SegmentType.STRING - - def test_output_with_deeply_nested_children(self): - """Test Output with deeply nested children.""" - inner_child = CodeNodeData.Output(type=SegmentType.NUMBER) - middle_child = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"value": inner_child}, - ) - outer_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"nested": middle_child}, - ) - - assert outer_output.children is not None - assert outer_output.children["nested"].children is not None - assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER - - def test_output_with_multiple_children(self): - """Test Output with multiple children.""" - output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - "active": CodeNodeData.Output(type=SegmentType.BOOLEAN), - }, - ) - - assert output.children is not None - assert len(output.children) == 3 - assert output.children["name"].type == SegmentType.STRING - assert output.children["age"].type == SegmentType.NUMBER - assert output.children["active"].type == SegmentType.BOOLEAN - - def test_output_rejects_invalid_type(self): - """Test Output rejects invalid segment types.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.FILE) - - def test_output_rejects_array_file_type(self): - """Test Output rejects ARRAY_FILE type.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.ARRAY_FILE) - - -class TestCodeNodeDataDependency: - """Test suite for CodeNodeData.Dependency model.""" - - def test_dependency_basic(self): - """Test Dependency with name and version.""" - dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0") - - assert dependency.name == "numpy" - assert dependency.version == "1.24.0" - - def test_dependency_with_complex_version(self): - """Test Dependency with complex version string.""" - dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0") - - assert dependency.name == "pandas" - assert dependency.version == ">=2.0.0,<3.0.0" - - def test_dependency_with_empty_version(self): - """Test Dependency with empty version.""" - dependency = CodeNodeData.Dependency(name="requests", version="") - - assert dependency.name == "requests" - assert dependency.version == "" - - -class TestCodeNodeData: - """Test suite for CodeNodeData model.""" - - def test_code_node_data_python3(self): - """Test CodeNodeData with Python3 language.""" - data = CodeNodeData( - title="Test Code Node", - variables=[], - 
code_language=CodeLanguage.PYTHON3, - code="def main(): return {'result': 42}", - outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert data.title == "Test Code Node" - assert data.code_language == CodeLanguage.PYTHON3 - assert data.code == "def main(): return {'result': 42}" - assert "result" in data.outputs - assert data.dependencies is None - - def test_code_node_data_javascript(self): - """Test CodeNodeData with JavaScript language.""" - data = CodeNodeData( - title="JS Code Node", - variables=[], - code_language=CodeLanguage.JAVASCRIPT, - code="function main() { return { result: 'hello' }; }", - outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert data.code_language == CodeLanguage.JAVASCRIPT - assert "result" in data.outputs - assert data.outputs["result"].type == SegmentType.STRING - - def test_code_node_data_with_dependencies(self): - """Test CodeNodeData with dependencies.""" - data = CodeNodeData( - title="Code with Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="import numpy as np\ndef main(): return {'sum': 10}", - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - dependencies=[ - CodeNodeData.Dependency(name="numpy", version="1.24.0"), - CodeNodeData.Dependency(name="pandas", version="2.0.0"), - ], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 2 - assert data.dependencies[0].name == "numpy" - assert data.dependencies[1].name == "pandas" - - def test_code_node_data_with_multiple_outputs(self): - """Test CodeNodeData with multiple outputs.""" - data = CodeNodeData( - title="Multi Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}", - outputs={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "count": CodeNodeData.Output(type=SegmentType.NUMBER), - "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING), - }, - ) - - assert len(data.outputs) == 3 - assert data.outputs["name"].type == SegmentType.STRING - assert data.outputs["count"].type == SegmentType.NUMBER - assert data.outputs["items"].type == SegmentType.ARRAY_STRING - - def test_code_node_data_with_object_output(self): - """Test CodeNodeData with nested object output.""" - data = CodeNodeData( - title="Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'user': {'name': 'John', 'age': 30}}", - outputs={ - "user": CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - }, - ), - }, - ) - - assert data.outputs["user"].type == SegmentType.OBJECT - assert data.outputs["user"].children is not None - assert len(data.outputs["user"].children) == 2 - - def test_code_node_data_with_array_object_output(self): - """Test CodeNodeData with array of objects output.""" - data = CodeNodeData( - title="Array Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}", - outputs={ - "users": CodeNodeData.Output( - type=SegmentType.ARRAY_OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - }, - ), - }, - ) - - assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT - assert data.outputs["users"].children is not None - - def test_code_node_data_empty_code(self): - """Test CodeNodeData with empty code.""" - data = CodeNodeData( - title="Empty Code", - 
variables=[], - code_language=CodeLanguage.PYTHON3, - code="", - outputs={}, - ) - - assert data.code == "" - assert len(data.outputs) == 0 - - def test_code_node_data_multiline_code(self): - """Test CodeNodeData with multiline code.""" - multiline_code = """ -def main(): - result = 0 - for i in range(10): - result += i - return {'sum': result} -""" - data = CodeNodeData( - title="Multiline Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=multiline_code, - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert "for i in range(10)" in data.code - assert "result += i" in data.code - - def test_code_node_data_with_special_characters_in_code(self): - """Test CodeNodeData with special characters in code.""" - code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}" - data = CodeNodeData( - title="Special Chars", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=code_with_special, - outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "\\n" in data.code - assert "\\t" in data.code - - def test_code_node_data_with_unicode_in_code(self): - """Test CodeNodeData with unicode characters in code.""" - unicode_code = "def main(): return {'greeting': '你好世界'}" - data = CodeNodeData( - title="Unicode Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=unicode_code, - outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "你好世界" in data.code - - def test_code_node_data_empty_dependencies_list(self): - """Test CodeNodeData with empty dependencies list.""" - data = CodeNodeData( - title="No Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {}", - outputs={}, - dependencies=[], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 0 - - def test_code_node_data_with_boolean_array_output(self): - """Test CodeNodeData with boolean array output.""" - data = CodeNodeData( - title="Boolean Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'flags': [True, False, True]}", - outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)}, - ) - - assert data.outputs["flags"].type == SegmentType.ARRAY_BOOLEAN - - def test_code_node_data_with_number_array_output(self): - """Test CodeNodeData with number array output.""" - data = CodeNodeData( - title="Number Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'values': [1, 2, 3, 4, 5]}", - outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)}, - ) - - assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 1d76067ec2..fb03ae9998 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: diff --git 
a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py deleted file mode 100644 index f1a48f49b9..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ /dev/null @@ -1,33 +0,0 @@ -from graphon.nodes.http_request import build_http_request_config - - -def test_build_http_request_config_uses_literal_defaults(): - config = build_http_request_config() - - assert config.max_connect_timeout == 10 - assert config.max_read_timeout == 600 - assert config.max_write_timeout == 600 - assert config.max_binary_size == 10 * 1024 * 1024 - assert config.max_text_size == 1 * 1024 * 1024 - assert config.ssl_verify is True - assert config.ssrf_default_max_retries == 3 - - -def test_build_http_request_config_supports_explicit_overrides(): - config = build_http_request_config( - max_connect_timeout=5, - max_read_timeout=30, - max_write_timeout=40, - max_binary_size=2048, - max_text_size=1024, - ssl_verify=False, - ssrf_default_max_retries=8, - ) - - assert config.max_connect_timeout == 5 - assert config.max_read_timeout == 30 - assert config.max_write_timeout == 40 - assert config.max_binary_size == 2048 - assert config.max_text_size == 1024 - assert config.ssl_verify is False - assert config.ssrf_default_max_retries == 8 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py deleted file mode 100644 index 88895608d9..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ /dev/null @@ -1,233 +0,0 @@ -import json -from unittest.mock import Mock, PropertyMock, patch - -import httpx -import pytest - -from graphon.nodes.http_request.entities import Response - - -@pytest.fixture -def mock_response(): - response = Mock(spec=httpx.Response) - response.headers = {} - return response - - -def test_is_file_with_attachment_disposition(mock_response): - """Test is_file when content-disposition header contains 'attachment'""" - mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_filename_disposition(mock_response): - """Test is_file when content-disposition header contains filename parameter""" - mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"]) -def test_is_file_with_file_content_types(mock_response, content_type): - """Test is_file with various file content types""" - mock_response.headers = {"content-type": content_type} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file, f"Content type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - "content_type", - [ - "application/json", - "application/xml", - "application/javascript", - "application/x-www-form-urlencoded", - "application/yaml", - "application/graphql", - ], -) -def test_text_based_application_types(mock_response, content_type): - """Test common text-based application types are not identified as files""" - mock_response.headers = {"content-type": content_type} - response = 
Response(mock_response) - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (b'{"key": "value"}', "application/octet-stream"), - (b"[1, 2, 3]", "application/unknown"), - (b"function test() {}", "application/x-unknown"), - (b"test", "application/binary"), - (b"var x = 1;", "application/data"), - ], -) -def test_content_based_detection(mock_response, content, content_type): - """Test content-based detection for text-like content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (bytes([0x00, 0xFF] * 512), "application/octet-stream"), - (bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"), # PNG magic numbers - (bytes([0xFF, 0xD8, 0xFF]), "application/binary"), # JPEG magic numbers - ], -) -def test_binary_content_detection(mock_response, content, content_type): - """Test content-based detection for binary content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert response.is_file, f"Binary content with type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - ("content_type", "expected_main_type"), - [ - ("x-world/x-vrml", "model"), # VRML 3D model - ("font/ttf", "application"), # TrueType font - ("text/csv", "text"), # CSV text file - ("unknown/xyz", None), # Unknown type - ], -) -def test_mimetype_based_detection(mock_response, content_type, expected_main_type): - """Test detection using mimetypes.guess_type for non-application content types""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - - with patch("graphon.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: - # Mock the return value based on expected_main_type - if expected_main_type: - mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) - else: - mock_guess_type.return_value = (None, None) - - response = Response(mock_response) - - # Check if the result matches our expectation - if expected_main_type in ("application", "image", "audio", "video"): - assert response.is_file, f"Content type {content_type} should be identified as a file" - else: - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - # Verify that guess_type was called - mock_guess_type.assert_called_once() - - -def test_is_file_with_inline_disposition(mock_response): - """Test is_file when content-disposition is 'inline'""" - mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_no_content_disposition(mock_response): - """Test is_file when no content-disposition header is present""" - mock_response.headers = {"content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert 
response.is_file - - -# UTF-8 Encoding Tests -@pytest.mark.parametrize( - ("content_bytes", "expected_text", "description"), - [ - # Chinese UTF-8 bytes - ( - b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', - '{"message": "你好世界"}', - "Chinese characters UTF-8", - ), - # Japanese UTF-8 bytes - ( - b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', - '{"message": "こんにちは"}', - "Japanese characters UTF-8", - ), - # Korean UTF-8 bytes - ( - b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', - '{"message": "안녕하세요"}', - "Korean characters UTF-8", - ), - # Arabic UTF-8 - (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"), - # European characters UTF-8 - (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"), - # Simple ASCII - (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), - ], -) -def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): - """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" - mock_response.headers = {"content-type": "application/json; charset=utf-8"} - type(mock_response).content = PropertyMock(return_value=content_bytes) - # Mock httpx response.text to return something different (simulating potential encoding issues) - mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property - - response = Response(mock_response) - - # Our enhanced text property should decode properly using charset_normalizer - assert response.text == expected_text, ( - f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" - ) - - -def test_text_property_fallback_to_httpx(mock_response): - """Test that Response.text falls back to httpx.text when charset_normalizer fails""" - mock_response.headers = {"content-type": "application/json"} - - # Create malformed UTF-8 bytes - malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' - type(mock_response).content = PropertyMock(return_value=malformed_bytes) - - # Mock httpx.text to return some fallback value - fallback_text = '{"text": "fallback"}' - mock_response.text = fallback_text - - response = Response(mock_response) - - # Should fall back to httpx's text when charset_normalizer fails - assert response.text == fallback_text - - -@pytest.mark.parametrize( - ("json_content", "description"), - [ - # JSON with escaped Unicode (like Flask jsonify()) - ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), - # JSON with mixed escape sequences and UTF-8 - ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), - # JSON with complex escape sequences - ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), - ], -) -def test_text_property_with_escaped_unicode(mock_response, json_content, description): - """Test Response.text with JSON containing Unicode escape sequences""" - mock_response.headers = {"content-type": "application/json"} - - content_bytes = json_content.encode("utf-8") - type(mock_response).content = PropertyMock(return_value=content_bytes) - mock_response.text = json_content # httpx would return the same for valid UTF-8 - - response = Response(mock_response) - - # Should preserve the escape sequences (valid JSON) - assert response.text == json_content, f"Failed for {description}" - - # The text should be valid JSON that can be parsed 
back to proper Unicode
-        parsed = json.loads(response.text)
-        assert isinstance(parsed, dict), f"Invalid JSON for {description}"
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
index be7cc073db..a5026b40cf 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
@@ -1,8 +1,4 @@
 import pytest
-
-from configs import dify_config
-from core.helper.ssrf_proxy import ssrf_proxy
-from core.workflow.system_variables import default_system_variables
 from graphon.file.file_manager import file_manager
 from graphon.nodes.http_request import (
     BodyData,
@@ -16,6 +12,10 @@ from graphon.nodes.http_request.exc import AuthorizationConfigError
 from graphon.nodes.http_request.executor import Executor
 from graphon.runtime import VariablePool

+from configs import dify_config
+from core.helper.ssrf_proxy import ssrf_proxy
+from core.workflow.system_variables import default_system_variables
+
 HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
     max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
     max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
index a3cadc0681..4705b3f76e 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -3,17 +3,17 @@ from typing import Any

 import httpx
 import pytest
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.file.file_manager import file_manager
+from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
+from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
+from graphon.runtime import GraphRuntimeState, VariablePool

 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.helper.ssrf_proxy import ssrf_proxy
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.node_runtime import DifyFileReferenceFactory
 from core.workflow.system_variables import build_system_variables
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.file.file_manager import file_manager
-from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
-from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
-from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.workflow_test_utils import build_test_graph_init_params

 HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
index 1d6a4da7c4..d16e1233ac 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
@@ -1,6 +1,7 @@
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
 from graphon.runtime import VariablePool

+from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
+

 def test_render_body_template_replaces_variable_values():
     config = EmailDeliveryConfig(
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
index 5f28a07606..a2cdbbf132 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
@@ -2,14 +2,41 @@
 Unit tests for human input node entities.
 """

+from collections.abc import Mapping
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
 from types import SimpleNamespace
+from typing import Any
 from unittest.mock import MagicMock

 import pytest
+from graphon.entities import GraphInitParams
+from graphon.node_events import PauseRequestedEvent
+from graphon.node_events.node import StreamCompletedEvent
+from graphon.nodes.human_input.entities import (
+    FormInput,
+    FormInputDefault,
+    HumanInputNodeData,
+    UserAction,
+)
+from graphon.nodes.human_input.enums import (
+    ButtonStyle,
+    FormInputType,
+    HumanInputFormStatus,
+    PlaceholderType,
+    TimeoutUnit,
+)
+from graphon.nodes.human_input.human_input_node import HumanInputNode
+from graphon.runtime import GraphRuntimeState, VariablePool
 from pydantic import ValidationError

 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
-from core.repositories.human_input_repository import HumanInputFormRepository
+from core.repositories.human_input_repository import (
+    FormCreateParams,
+    HumanInputFormEntity,
+    HumanInputFormRecipientEntity,
+    HumanInputFormRepository,
+)
 from core.workflow.human_input_compat import (
     DeliveryMethodType,
     EmailDeliveryConfig,
@@ -23,24 +50,90 @@ from core.workflow.human_input_compat import (
 )
 from core.workflow.node_runtime import DifyHumanInputNodeRuntime
 from core.workflow.system_variables import build_system_variables
-from graphon.entities import GraphInitParams
-from graphon.node_events import PauseRequestedEvent
-from graphon.node_events.node import StreamCompletedEvent
-from graphon.nodes.human_input.entities import (
-    FormInput,
-    FormInputDefault,
-    HumanInputNodeData,
-    UserAction,
-)
-from graphon.nodes.human_input.enums import (
-    ButtonStyle,
-    FormInputType,
-    PlaceholderType,
-    TimeoutUnit,
-)
-from graphon.nodes.human_input.human_input_node import HumanInputNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository
+from libs.datetime_utils import naive_utc_now
+
+
+@dataclass
+class _InMemoryFormEntity(HumanInputFormEntity):
+    form_id: str
+    rendered: str
+    token: str | None = None
+    action_id: str | None = None
+    data: Mapping[str, Any] | None = None
+    is_submitted: bool = False
+    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
+    expiration: datetime = field(default_factory=lambda: naive_utc_now() + timedelta(days=1))
+
+    @property
+    def id(self) -> str:
+        return self.form_id
+
+    @property
+    def submission_token(self) -> str | None:
+        return self.token
+
+    @property
+    def recipients(self) -> list[HumanInputFormRecipientEntity]:
+        return []
+
+    @property
+    def rendered_content(self) -> str:
+        return self.rendered
+
+    @property
+    def selected_action_id(self) -> str | None:
+        return self.action_id
+
+    @property
+    def submitted_data(self) -> Mapping[str, Any] | None:
+        return self.data
+
+    @property
+    def submitted(self) -> bool:
+        return self.is_submitted
+
+    @property
+    def status(self) -> HumanInputFormStatus:
+        return self.status_value
+
+    @property
+    def expiration_time(self) -> datetime:
+        return self.expiration
+
+
+class InMemoryHumanInputFormRepository(HumanInputFormRepository):
+    """Minimal in-memory repository for Dify-owned HumanInputNode behavior tests."""
+
+    def __init__(self) -> None:
+        self._form_counter = 0
+        self.created_params: list[FormCreateParams] = []
+        self.created_forms: list[_InMemoryFormEntity] = []
+        self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {}
+
+    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
+        self.created_params.append(params)
+        self._form_counter += 1
+        form_id = f"form-{self._form_counter}"
+        entity = _InMemoryFormEntity(
+            form_id=form_id,
+            rendered=params.rendered_content,
+            token=f"token-{form_id}",
+        )
+        self.created_forms.append(entity)
+        self._forms_by_node_id[params.node_id] = entity
+        return entity
+
+    def get_form(self, node_id: str) -> HumanInputFormEntity | None:
+        return self._forms_by_node_id.get(node_id)
+
+    def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
+        if not self.created_forms:
+            raise AssertionError("no form has been created to attach submission data")
+        entity = self.created_forms[-1]
+        entity.action_id = action_id
+        entity.data = form_data or {}
+        entity.is_submitted = True
+        entity.status_value = HumanInputFormStatus.SUBMITTED


 class TestDeliveryMethod:
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
index fc4497f010..52802c7ce1 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
@@ -1,10 +1,7 @@
 import datetime
 from types import SimpleNamespace

-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import default_system_variables
-from graphon.entities.graph_init_params import GraphInitParams
+from graphon.entities import GraphInitParams
 from graphon.enums import BuiltinNodeTypes
 from graphon.graph_events import (
     NodeRunHumanInputFormFilledEvent,
@@ -14,6 +11,10 @@ from graphon.graph_events import (
 from graphon.nodes.human_input.enums import HumanInputFormStatus
 from graphon.nodes.human_input.human_input_node import HumanInputNode
 from graphon.runtime import GraphRuntimeState, VariablePool
+
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
+from core.workflow.node_runtime import DifyHumanInputNodeRuntime
+from core.workflow.system_variables import default_system_variables
 from libs.datetime_utils import naive_utc_now
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py
deleted file mode 100644
index 8cc91bdb54..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py
+++ /dev/null
@@ -1,339 +0,0 @@
-from graphon.nodes.iteration.entities import (
-    ErrorHandleMode,
-    IterationNodeData,
-    IterationStartNodeData,
-    IterationState,
-)
-
-
-class TestErrorHandleMode:
-    """Test suite for ErrorHandleMode enum."""
-
-    def test_terminated_value(self):
-        """Test TERMINATED enum value."""
-        assert ErrorHandleMode.TERMINATED == "terminated"
-        assert ErrorHandleMode.TERMINATED.value == "terminated"
-
-    def test_continue_on_error_value(self):
-        """Test CONTINUE_ON_ERROR enum value."""
-        assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
-        assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error"
-
-    def test_remove_abnormal_output_value(self):
-        """Test REMOVE_ABNORMAL_OUTPUT enum value."""
-        assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output"
-        assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output"
-
-    def test_error_handle_mode_is_str_enum(self):
-        """Test ErrorHandleMode is a string enum."""
-        assert isinstance(ErrorHandleMode.TERMINATED, str)
-        assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str)
-        assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str)
-
-    def test_error_handle_mode_comparison(self):
-        """Test ErrorHandleMode can be compared with strings."""
-        assert ErrorHandleMode.TERMINATED == "terminated"
-        assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
-
-    def test_all_error_handle_modes(self):
-        """Test all ErrorHandleMode values are accessible."""
-        modes = list(ErrorHandleMode)
-
-        assert len(modes) == 3
-        assert ErrorHandleMode.TERMINATED in modes
-        assert ErrorHandleMode.CONTINUE_ON_ERROR in modes
-        assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes
-
-
-class TestIterationNodeData:
-    """Test suite for IterationNodeData model."""
-
-    def test_iteration_node_data_basic(self):
-        """Test IterationNodeData with basic configuration."""
-        data = IterationNodeData(
-            title="Test Iteration",
-            iterator_selector=["node1", "output"],
-            output_selector=["iteration", "result"],
-        )
-
-        assert data.title == "Test Iteration"
-        assert data.iterator_selector == ["node1", "output"]
-        assert data.output_selector == ["iteration", "result"]
-
-    def test_iteration_node_data_default_values(self):
-        """Test IterationNodeData default values."""
-        data = IterationNodeData(
-            title="Default Test",
-            iterator_selector=["start", "items"],
-            output_selector=["iter", "out"],
-        )
-
-        assert data.parent_loop_id is None
-        assert data.is_parallel is False
-        assert data.parallel_nums == 10
-        assert data.error_handle_mode == ErrorHandleMode.TERMINATED
-        assert data.flatten_output is True
-
-    def test_iteration_node_data_parallel_mode(self):
-        """Test IterationNodeData with parallel mode enabled."""
-        data = IterationNodeData(
-            title="Parallel Iteration",
-            iterator_selector=["node", "list"],
-            output_selector=["iter", "output"],
-            is_parallel=True,
-            parallel_nums=5,
-        )
-
-        assert data.is_parallel is True
-        assert data.parallel_nums == 5
-
-    def test_iteration_node_data_custom_parallel_nums(self):
-        """Test IterationNodeData with custom parallel numbers."""
-        data = IterationNodeData(
-            title="Custom Parallel",
-            iterator_selector=["a", "b"],
-            output_selector=["c", "d"],
-            parallel_nums=20,
-        )
-
-        assert data.parallel_nums == 20
-
-    def test_iteration_node_data_continue_on_error(self):
-        """Test IterationNodeData with continue on error mode."""
-        data = IterationNodeData(
-            title="Continue Error",
-            iterator_selector=["x", "y"],
-            output_selector=["z", "w"],
-            error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
-        )
-
-        assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
-
-    def test_iteration_node_data_remove_abnormal_output(self):
-        """Test IterationNodeData with remove abnormal output mode."""
-        data = IterationNodeData(
-            title="Remove Abnormal",
-            iterator_selector=["input", "array"],
-            output_selector=["output", "result"],
-            error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
-        )
-
-        assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
-
-    def test_iteration_node_data_flatten_output_disabled(self):
-        """Test IterationNodeData with flatten output disabled."""
-        data = IterationNodeData(
-            title="No Flatten",
-            iterator_selector=["a"],
-            output_selector=["b"],
-            flatten_output=False,
-        )
-
-        assert data.flatten_output is False
-
-    def test_iteration_node_data_with_parent_loop_id(self):
-        """Test IterationNodeData with parent loop ID."""
-        data = IterationNodeData(
-            title="Nested Loop",
-            iterator_selector=["parent", "items"],
-            output_selector=["child", "output"],
-            parent_loop_id="parent_loop_123",
-        )
-
-        assert data.parent_loop_id == "parent_loop_123"
-
-    def test_iteration_node_data_complex_selectors(self):
-        """Test IterationNodeData with complex selectors."""
-        data = IterationNodeData(
-            title="Complex Selectors",
-            iterator_selector=["node1", "output", "data", "items"],
-            output_selector=["iteration", "result", "value"],
-        )
-
-        assert len(data.iterator_selector) == 4
-        assert len(data.output_selector) == 3
-
-    def test_iteration_node_data_all_options(self):
-        """Test IterationNodeData with all options configured."""
-        data = IterationNodeData(
-            title="Full Config",
-            iterator_selector=["start", "list"],
-            output_selector=["end", "result"],
-            parent_loop_id="outer_loop",
-            is_parallel=True,
-            parallel_nums=15,
-            error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
-            flatten_output=False,
-        )
-
-        assert data.title == "Full Config"
-        assert data.parent_loop_id == "outer_loop"
-        assert data.is_parallel is True
-        assert data.parallel_nums == 15
-        assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
-        assert data.flatten_output is False
-
-
-class TestIterationStartNodeData:
-    """Test suite for IterationStartNodeData model."""
-
-    def test_iteration_start_node_data_basic(self):
-        """Test IterationStartNodeData basic creation."""
-        data = IterationStartNodeData(title="Iteration Start")
-
-        assert data.title == "Iteration Start"
-
-    def test_iteration_start_node_data_with_description(self):
-        """Test IterationStartNodeData with description."""
-        data = IterationStartNodeData(
-            title="Start Node",
-            desc="This is the start of iteration",
-        )
-
-        assert data.title == "Start Node"
-        assert data.desc == "This is the start of iteration"
-
-
-class TestIterationState:
-    """Test suite for IterationState model."""
-
-    def test_iteration_state_default_values(self):
-        """Test IterationState default values."""
-        state = IterationState()
-
-        assert state.outputs == []
-        assert state.current_output is None
-
-    def test_iteration_state_with_outputs(self):
-        """Test IterationState with outputs."""
-        state = IterationState(outputs=["result1", "result2", "result3"])
-
-        assert len(state.outputs) == 3
-        assert state.outputs[0] == "result1"
-        assert state.outputs[2] == "result3"
-
-    def test_iteration_state_with_current_output(self):
-        """Test IterationState with current output."""
-        state = IterationState(current_output="current_value")
-
-        assert state.current_output == "current_value"
-
-    def test_iteration_state_get_last_output_with_outputs(self):
-        """Test get_last_output with outputs present."""
-        state = IterationState(outputs=["first", "second", "last"])
-
-        result = state.get_last_output()
-
-        assert result == "last"
-
-    def test_iteration_state_get_last_output_empty(self):
-        """Test get_last_output with empty outputs."""
-        state = IterationState(outputs=[])
-
-        result = state.get_last_output()
-
-        assert result is None
-
-    def test_iteration_state_get_last_output_single(self):
-        """Test get_last_output with single output."""
-        state = IterationState(outputs=["only_one"])
-
-        result = state.get_last_output()
-
-        assert result == "only_one"
-
-    def test_iteration_state_get_current_output(self):
-        """Test get_current_output method."""
-        state = IterationState(current_output={"key": "value"})
-
-        result = state.get_current_output()
-
-        assert result == {"key": "value"}
-
-    def test_iteration_state_get_current_output_none(self):
-        """Test get_current_output when None."""
-        state = IterationState()
-
-        result = state.get_current_output()
-
-        assert result is None
-
-    def test_iteration_state_with_complex_outputs(self):
-        """Test IterationState with complex output types."""
-        state = IterationState(
-            outputs=[
-                {"id": 1, "name": "first"},
-                {"id": 2, "name": "second"},
-                [1, 2, 3],
-                "string_output",
-            ]
-        )
-
-        assert len(state.outputs) == 4
-        assert state.outputs[0] == {"id": 1, "name": "first"}
-        assert state.outputs[2] == [1, 2, 3]
-
-    def test_iteration_state_with_none_outputs(self):
-        """Test IterationState with None values in outputs."""
-        state = IterationState(outputs=["value1", None, "value3"])
-
-        assert len(state.outputs) == 3
-        assert state.outputs[1] is None
-
-    def test_iteration_state_get_last_output_with_none(self):
-        """Test get_last_output when last output is None."""
-        state = IterationState(outputs=["first", None])
-
-        result = state.get_last_output()
-
-        assert result is None
-
-    def test_iteration_state_metadata_class(self):
-        """Test IterationState.MetaData class."""
-        metadata = IterationState.MetaData(iterator_length=10)
-
-        assert metadata.iterator_length == 10
-
-    def test_iteration_state_metadata_different_lengths(self):
-        """Test IterationState.MetaData with different lengths."""
-        metadata1 = IterationState.MetaData(iterator_length=0)
-        metadata2 = IterationState.MetaData(iterator_length=100)
-        metadata3 = IterationState.MetaData(iterator_length=1000000)
-
-        assert metadata1.iterator_length == 0
-        assert metadata2.iterator_length == 100
-        assert metadata3.iterator_length == 1000000
-
-    def test_iteration_state_outputs_modification(self):
-        """Test modifying IterationState outputs."""
-        state = IterationState(outputs=[])
-
-        state.outputs.append("new_output")
-        state.outputs.append("another_output")
-
-        assert len(state.outputs) == 2
-        assert state.get_last_output() == "another_output"
-
-    def test_iteration_state_current_output_update(self):
-        """Test updating current_output."""
-        state = IterationState()
-
-        state.current_output = "first_value"
-        assert state.get_current_output() == "first_value"
-
-        state.current_output = "updated_value"
-        assert state.get_current_output() == "updated_value"
-
-    def test_iteration_state_with_numeric_outputs(self):
-        """Test IterationState with numeric outputs."""
-        state = IterationState(outputs=[1, 2, 3, 4, 5])
-
-        assert state.get_last_output() == 5
-        assert len(state.outputs) == 5
-
-    def test_iteration_state_with_boolean_outputs(self):
-        """Test IterationState with boolean outputs."""
-        state = IterationState(outputs=[True, False, True])
-
-        assert state.get_last_output() is True
-        assert state.outputs[1] is False
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py
deleted file mode 100644
index 58b82aa893..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py
+++ /dev/null
@@ -1,438 +0,0 @@
-from graphon.entities.graph_config import NodeConfigDictAdapter
-from graphon.enums import BuiltinNodeTypes
-from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
-from graphon.nodes.iteration.exc import (
-    InvalidIteratorValueError,
-    IterationGraphNotFoundError,
-    IterationIndexNotFoundError,
-    IterationNodeError,
-    IteratorVariableNotFoundError,
-    StartNodeIdNotFoundError,
-)
-from graphon.nodes.iteration.iteration_node import IterationNode
-
-
-class TestIterationNodeExceptions:
-    """Test suite for iteration node exceptions."""
-
-    def test_iteration_node_error_is_value_error(self):
-        """Test IterationNodeError inherits from ValueError."""
-        error = IterationNodeError("test error")
-
-        assert isinstance(error, ValueError)
-        assert str(error) == "test error"
-
-    def test_iterator_variable_not_found_error(self):
-        """Test IteratorVariableNotFoundError."""
-        error = IteratorVariableNotFoundError("Iterator variable not found")
-
-        assert isinstance(error, IterationNodeError)
-        assert isinstance(error, ValueError)
-        assert "Iterator variable not found" in str(error)
-
-    def test_invalid_iterator_value_error(self):
-        """Test InvalidIteratorValueError."""
-        error = InvalidIteratorValueError("Invalid iterator value")
-
-        assert isinstance(error, IterationNodeError)
-        assert "Invalid iterator value" in str(error)
-
-    def test_start_node_id_not_found_error(self):
-        """Test StartNodeIdNotFoundError."""
-        error = StartNodeIdNotFoundError("Start node ID not found")
-
-        assert isinstance(error, IterationNodeError)
-        assert "Start node ID not found" in str(error)
-
-    def test_iteration_graph_not_found_error(self):
-        """Test IterationGraphNotFoundError."""
-        error = IterationGraphNotFoundError("Iteration graph not found")
-
-        assert isinstance(error, IterationNodeError)
-        assert "Iteration graph not found" in str(error)
-
-    def test_iteration_index_not_found_error(self):
-        """Test IterationIndexNotFoundError."""
-        error = IterationIndexNotFoundError("Iteration index not found")
-
-        assert isinstance(error, IterationNodeError)
-        assert "Iteration index not found" in str(error)
-
-    def test_exception_with_empty_message(self):
-        """Test exception with empty message."""
-        error = IterationNodeError("")
-
-        assert str(error) == ""
-
-    def test_exception_with_detailed_message(self):
-        """Test exception with detailed message."""
-        error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'")
-
-        assert "items" in str(error)
-        assert "start_node" in str(error)
-
-    def test_all_exceptions_inherit_from_base(self):
-        """Test all exceptions inherit from IterationNodeError."""
-        exceptions = [
-            IteratorVariableNotFoundError("test"),
-            InvalidIteratorValueError("test"),
-            StartNodeIdNotFoundError("test"),
-            IterationGraphNotFoundError("test"),
-            IterationIndexNotFoundError("test"),
-        ]
-
-        for exc in exceptions:
-            assert isinstance(exc, IterationNodeError)
-            assert isinstance(exc, ValueError)
-
-
-class TestIterationNodeClassAttributes:
-    """Test suite for IterationNode class attributes."""
-
-    def test_node_type(self):
-        """Test IterationNode node_type attribute."""
-        assert IterationNode.node_type == BuiltinNodeTypes.ITERATION
-
-    def test_version(self):
-        """Test IterationNode version method."""
-        version = IterationNode.version()
-
-        assert version == "1"
-
-
-class TestIterationNodeDefaultConfig:
-    """Test suite for IterationNode get_default_config."""
-
-    def test_get_default_config_returns_dict(self):
-        """Test get_default_config returns a dictionary."""
-        config = IterationNode.get_default_config()
-
-        assert isinstance(config, dict)
-
-    def test_get_default_config_type(self):
-        """Test get_default_config includes type."""
-        config = IterationNode.get_default_config()
-
-        assert config.get("type") == "iteration"
-
-    def test_get_default_config_has_config_section(self):
-        """Test get_default_config has config section."""
-        config = IterationNode.get_default_config()
-
-        assert "config" in config
-        assert isinstance(config["config"], dict)
-
-    def test_get_default_config_is_parallel_default(self):
-        """Test get_default_config is_parallel default value."""
-        config = IterationNode.get_default_config()
-
-        assert config["config"]["is_parallel"] is False
-
-    def test_get_default_config_parallel_nums_default(self):
-        """Test get_default_config parallel_nums default value."""
-        config = IterationNode.get_default_config()
-
-        assert config["config"]["parallel_nums"] == 10
-
-    def test_get_default_config_error_handle_mode_default(self):
-        """Test get_default_config error_handle_mode default value."""
-        config = IterationNode.get_default_config()
-
-        assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED
-
-    def test_get_default_config_flatten_output_default(self):
-        """Test get_default_config flatten_output default value."""
-        config = IterationNode.get_default_config()
-
-        assert config["config"]["flatten_output"] is True
-
-    def test_get_default_config_with_none_filters(self):
-        """Test get_default_config with None filters."""
-        config = IterationNode.get_default_config(filters=None)
-
-        assert config is not None
-        assert "type" in config
-
-    def test_get_default_config_with_empty_filters(self):
-        """Test get_default_config with empty filters."""
-        config = IterationNode.get_default_config(filters={})
-
-        assert config is not None
-
-
-class TestIterationNodeInitialization:
-    """Test suite for IterationNode initialization."""
-
-    def test_init_node_data_basic(self):
-        """Test init_node_data with basic configuration."""
-        node = IterationNode.__new__(IterationNode)
-        data = {
-            "title": "Test Iteration",
-            "iterator_selector": ["start", "items"],
-            "output_selector": ["iteration", "result"],
-        }
-
-        node.init_node_data(data)
-
-        assert node._node_data.title == "Test Iteration"
-        assert node._node_data.iterator_selector == ["start", "items"]
-
-    def test_init_node_data_with_parallel(self):
-        """Test init_node_data with parallel configuration."""
-        node = IterationNode.__new__(IterationNode)
-        data = {
-            "title": "Parallel Iteration",
-            "iterator_selector": ["node", "list"],
-            "output_selector": ["out", "result"],
-            "is_parallel": True,
-            "parallel_nums": 5,
-        }
-
-        node.init_node_data(data)
-
-        assert node._node_data.is_parallel is True
-        assert node._node_data.parallel_nums == 5
-
-    def test_init_node_data_with_error_handle_mode(self):
-        """Test init_node_data with error handle mode."""
-        node = IterationNode.__new__(IterationNode)
-        data = {
-            "title": "Error Handle Test",
-            "iterator_selector": ["a", "b"],
-            "output_selector": ["c", "d"],
-            "error_handle_mode": "continue-on-error",
-        }
-
-        node.init_node_data(data)
-
-        assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
-
-    def test_get_title(self):
-        """Test _get_title method."""
-        node = IterationNode.__new__(IterationNode)
-        node._node_data = IterationNodeData(
-            title="My Iteration",
-            iterator_selector=["x"],
-            output_selector=["y"],
-        )
-
-        assert node._get_title() == "My Iteration"
-
-    def test_get_description_none(self):
-        """Test _get_description returns None when not set."""
-        node = IterationNode.__new__(IterationNode)
-        node._node_data = IterationNodeData(
-            title="Test",
-            iterator_selector=["a"],
-            output_selector=["b"],
-        )
-
-        assert node._get_description() is None
-
-    def test_get_description_with_value(self):
-        """Test _get_description with value."""
-        node = IterationNode.__new__(IterationNode)
-        node._node_data = IterationNodeData(
-            title="Test",
-            desc="This is a description",
-            iterator_selector=["a"],
-            output_selector=["b"],
-        )
-
-        assert node._get_description() == "This is a description"
-
-    def test_node_data_property(self):
-        """Test node_data property returns node data."""
-        node = IterationNode.__new__(IterationNode)
-        node._node_data = IterationNodeData(
-            title="Base Test",
-            iterator_selector=["x"],
-            output_selector=["y"],
-        )
-
-        result = node.node_data
-
-        assert result == node._node_data
-
-
-class TestIterationNodeDataValidation:
-    """Test suite for IterationNodeData validation scenarios."""
-
-    def test_valid_iteration_node_data(self):
-        """Test valid IterationNodeData creation."""
-        data = IterationNodeData(
-            title="Valid Iteration",
-            iterator_selector=["start", "items"],
-            output_selector=["end", "result"],
-        )
-
-        assert data.title == "Valid Iteration"
-
-    def test_iteration_node_data_with_all_error_modes(self):
-        """Test IterationNodeData with all error handle modes."""
-        modes = [
-            ErrorHandleMode.TERMINATED,
-            ErrorHandleMode.CONTINUE_ON_ERROR,
-            ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
-        ]
-
-        for mode in modes:
-            data = IterationNodeData(
-                title=f"Test {mode}",
-                iterator_selector=["a"],
-                output_selector=["b"],
-                error_handle_mode=mode,
-            )
-            assert data.error_handle_mode == mode
-
-    def test_iteration_node_data_parallel_configuration(self):
-        """Test IterationNodeData parallel configuration combinations."""
-        configs = [
-            (False, 10),
-            (True, 1),
-            (True, 5),
-            (True, 20),
-            (True, 100),
-        ]
-
-        for is_parallel, parallel_nums in configs:
-            data = IterationNodeData(
-                title="Parallel Test",
-                iterator_selector=["x"],
-                output_selector=["y"],
-                is_parallel=is_parallel,
-                parallel_nums=parallel_nums,
-            )
-            assert data.is_parallel == is_parallel
-            assert data.parallel_nums == parallel_nums
-
-    def test_iteration_node_data_flatten_output_options(self):
-        """Test IterationNodeData flatten_output options."""
-        data_flatten = IterationNodeData(
-            title="Flatten True",
-            iterator_selector=["a"],
-            output_selector=["b"],
-            flatten_output=True,
-        )
-
-        data_no_flatten = IterationNodeData(
-            title="Flatten False",
-            iterator_selector=["a"],
-            output_selector=["b"],
-            flatten_output=False,
-        )
-
-        assert data_flatten.flatten_output is True
-        assert data_no_flatten.flatten_output is False
-
-    def test_iteration_node_data_complex_selectors(self):
-        """Test IterationNodeData with complex selectors."""
-        data = IterationNodeData(
-            title="Complex",
-            iterator_selector=["node1", "output", "data", "items", "list"],
-            output_selector=["iteration", "result", "value", "final"],
-        )
-
-        assert len(data.iterator_selector) == 5
-        assert len(data.output_selector) == 4
-
-    def test_iteration_node_data_single_element_selectors(self):
-        """Test IterationNodeData with single element selectors."""
-        data = IterationNodeData(
-            title="Single",
-            iterator_selector=["items"],
-            output_selector=["result"],
-        )
-
-        assert len(data.iterator_selector) == 1
-        assert len(data.output_selector) == 1
-
-
-class TestIterationNodeErrorStrategies:
-    """Test suite for IterationNode error strategies."""
-
-    def test_get_error_strategy_default(self):
-        """Test _get_error_strategy with default value."""
-        node = IterationNode.__new__(IterationNode)
-        node._node_data = IterationNodeData(
-            title="Test",
-            iterator_selector=["a"],
-            output_selector=["b"],
-        )
-
-        result = node._get_error_strategy()
-
-        assert result is None or result == node._node_data.error_strategy
-
-    def test_get_retry_config(self):
-        """Test _get_retry_config method."""
-        node = IterationNode.__new__(IterationNode)
-        node._node_data = IterationNodeData(
-            title="Test",
-            iterator_selector=["a"],
-            output_selector=["b"],
-        )
-
-        result = node._get_retry_config()
-
-        assert result is not None
-
-    def test_get_default_value_dict(self):
-        """Test _get_default_value_dict method."""
-        node = IterationNode.__new__(IterationNode)
-        node._node_data = IterationNodeData(
-            title="Test",
-            iterator_selector=["a"],
-            output_selector=["b"],
-        )
-
-        result = node._get_default_value_dict()
-
-        assert isinstance(result, dict)
-
-
-def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
-    seen_configs: list[object] = []
-    original_validate_python = NodeConfigDictAdapter.validate_python
-
-    def record_validate_python(value: object):
-        seen_configs.append(value)
-        return original_validate_python(value)
-
-    monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
-
-    child_node_config = {
-        "id": "answer-node",
-        "data": {
-            "type": "answer",
-            "title": "Answer",
-            "answer": "",
-            "iteration_id": "iteration-node",
-        },
-    }
-
-    IterationNode._extract_variable_selector_to_variable_mapping(
-        graph_config={
-            "nodes": [
-                {
-                    "id": "iteration-node",
-                    "data": {
-                        "type": "iteration",
-                        "title": "Iteration",
-                        "iterator_selector": ["start", "items"],
-                        "output_selector": ["iteration", "result"],
-                    },
-                },
-                child_node_config,
-            ],
-            "edges": [],
-        },
-        node_id="iteration-node",
-        node_data=IterationNodeData(
-            title="Iteration",
-            iterator_selector=["start", "items"],
-            output_selector=["iteration", "result"],
-        ),
-    )
-
-    assert seen_configs == [child_node_config]
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py
deleted file mode 100644
index 4c3ad85fcd..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py
+++ /dev/null
@@ -1,201 +0,0 @@
-from threading import Event
-from types import SimpleNamespace
-from unittest.mock import MagicMock
-
-import pytest
-
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.graph_events import GraphRunAbortedEvent
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.node_events import IterationFailedEvent, IterationStartedEvent, StreamCompletedEvent
-from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
-from graphon.nodes.iteration.exc import ChildGraphAbortedError
-from graphon.nodes.iteration.iteration_node import IterationNode
-from tests.workflow_test_utils import build_test_variable_pool
-
-
-def _usage_with_tokens(total_tokens: int) -> LLMUsage:
-    usage = LLMUsage.empty_usage()
-    usage.total_tokens = total_tokens
-    return usage
-
-
-class _AbortOnRequestGraphEngine:
-    def __init__(self, *, index: int, total_tokens: int) -> None:
-        variable_pool = build_test_variable_pool()
-        variable_pool.add(["iteration-node", "index"], index)
-
-        self.started = Event()
-        self.abort_requested = Event()
-        self.finished = Event()
-        self.abort_reason: str | None = None
-        self.graph_runtime_state = SimpleNamespace(
-            variable_pool=variable_pool,
-            llm_usage=_usage_with_tokens(total_tokens),
-        )
-
-    def request_abort(self, reason: str | None = None) -> None:
-        self.abort_reason = reason
-        self.abort_requested.set()
-
-    def run(self):
-        self.started.set()
-        assert self.abort_requested.wait(1), "parallel sibling never received an abort request"
-        self.finished.set()
-        yield GraphRunAbortedEvent(reason=self.abort_reason)
-
-
-def _build_immediate_abort_graph_engine(
-    *,
-    index: int,
-    total_tokens: int,
-    wait_before_abort: Event | None = None,
-) -> SimpleNamespace:
-    variable_pool = build_test_variable_pool()
-    variable_pool.add(["iteration-node", "index"], index)
-
-    started = Event()
-    finished = Event()
-
-    def run():
-        started.set()
-        if wait_before_abort is not None:
-            assert wait_before_abort.wait(1), "parallel sibling never started"
-        finished.set()
-        yield GraphRunAbortedEvent(reason="quota exceeded")
-
-    return SimpleNamespace(
-        graph_runtime_state=SimpleNamespace(
-            variable_pool=variable_pool,
-            llm_usage=_usage_with_tokens(total_tokens),
-        ),
-        run=run,
-        request_abort=lambda reason=None: None,
-        started=started,
-        finished=finished,
-    )
-
-
-def _build_iteration_node(
-    *,
-    error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED,
-    is_parallel: bool = False,
-) -> IterationNode:
-    node = IterationNode.__new__(IterationNode)
-    node._node_id = "iteration-node"
-    node._node_data = IterationNodeData(
-        title="Iteration",
-        iterator_selector=["start", "items"],
-        output_selector=["iteration-node", "output"],
-        start_node_id="child-start",
-        is_parallel=is_parallel,
-        parallel_nums=2,
-        error_handle_mode=error_handle_mode,
-    )
-
-    variable_pool = build_test_variable_pool()
-    variable_pool.add(["start", "items"], ["first", "second"])
-    node.graph_runtime_state = SimpleNamespace(
-        variable_pool=variable_pool,
-        llm_usage=LLMUsage.empty_usage(),
-    )
-    return node
-
-
-def test_run_single_iter_raises_child_graph_aborted_error_on_abort_event() -> None:
-    node = _build_iteration_node()
-    variable_pool = build_test_variable_pool()
-    variable_pool.add(["iteration-node", "index"], 0)
-    graph_engine = SimpleNamespace(
-        run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]),
-    )
-
-    with pytest.raises(ChildGraphAbortedError, match="quota exceeded"):
-        list(
-            node._run_single_iter(
-                variable_pool=variable_pool,
-                outputs=[],
-                graph_engine=graph_engine,
-            )
-        )
-
-
-def test_iteration_run_fails_on_sequential_child_abort() -> None:
-    node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR)
-    graph_engine = SimpleNamespace(
-        graph_runtime_state=SimpleNamespace(
-            variable_pool=build_test_variable_pool(),
-            llm_usage=LLMUsage.empty_usage(),
-        )
-    )
-    node._create_graph_engine = MagicMock(return_value=graph_engine)
-    node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded"))
-
-    events = list(node._run())
-
-    assert isinstance(events[0], IterationStartedEvent)
-    assert isinstance(events[-2], IterationFailedEvent)
-    assert events[-2].error == "quota exceeded"
-    assert isinstance(events[-1], StreamCompletedEvent)
-    assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED
-    assert events[-1].node_run_result.error == "quota exceeded"
-    node._create_graph_engine.assert_called_once()
-    node._run_single_iter.assert_called_once()
-
-
-def test_iteration_run_merges_child_usage_before_failing_on_sequential_child_abort() -> None:
-    node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR)
-    graph_engine = SimpleNamespace(
-        graph_runtime_state=SimpleNamespace(
-            variable_pool=build_test_variable_pool(),
-            llm_usage=_usage_with_tokens(7),
-        )
-    )
-    node._create_graph_engine = MagicMock(return_value=graph_engine)
-    node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded"))
-
-    events = list(node._run())
-
-    assert isinstance(events[-1], StreamCompletedEvent)
-    assert events[-1].node_run_result.llm_usage.total_tokens == 7
-    assert node.graph_runtime_state.llm_usage.total_tokens == 7
-
-
-@pytest.mark.parametrize(
-    "error_handle_mode",
-    [
-        ErrorHandleMode.CONTINUE_ON_ERROR,
-        ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
-    ],
-)
-def test_iteration_run_fails_on_parallel_child_abort_regardless_of_error_mode(
-    error_handle_mode: ErrorHandleMode,
-) -> None:
-    node = _build_iteration_node(
-        error_handle_mode=error_handle_mode,
-        is_parallel=True,
-    )
-    blocking_engine = _AbortOnRequestGraphEngine(index=1, total_tokens=5)
-    aborting_engine = _build_immediate_abort_graph_engine(
-        index=0,
-        total_tokens=3,
-        wait_before_abort=blocking_engine.started,
-    )
-    node._create_graph_engine = MagicMock(
-        side_effect=lambda index, item: {0: aborting_engine, 1: blocking_engine}[index]
-    )
-
-    events = list(node._run())
-
-    assert isinstance(events[0], IterationStartedEvent)
-    assert isinstance(events[-2], IterationFailedEvent)
-    assert events[-2].error == "quota exceeded"
-    assert isinstance(events[-1], StreamCompletedEvent)
-    assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED
-    assert events[-1].node_run_result.error == "quota exceeded"
-    assert events[-1].node_run_result.llm_usage.total_tokens == 8
-    assert node.graph_runtime_state.llm_usage.total_tokens == 8
-    assert blocking_engine.started.is_set()
-    assert blocking_engine.abort_requested.is_set()
-    assert blocking_engine.finished.is_set()
-    assert blocking_engine.abort_reason == "quota exceeded"
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
index 82cc734274..bbfe350f7e 100644
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
@@ -2,8 +2,6 @@ from collections.abc import Mapping
 from typing import Any

 import pytest
-
-from core.workflow.system_variables import default_system_variables
 from graphon.entities import GraphInitParams
 from graphon.nodes.iteration.exc import IterationGraphNotFoundError
 from graphon.nodes.iteration.iteration_node import IterationNode
@@ -13,6 +11,8 @@ from graphon.runtime import (
     GraphRuntimeState,
     VariablePool,
 )
+
+from core.workflow.system_variables import default_system_variables
 from tests.workflow_test_utils import build_test_graph_init_params
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py
deleted file mode 100644
index 41d7c3193d..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import time
-from datetime import UTC, datetime
-
-import pytest
-
-from graphon.enums import BuiltinNodeTypes
-from graphon.graph_events import NodeRunSucceededEvent
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
-from graphon.nodes.iteration.iteration_node import IterationNode
-
-
-def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
-    node = IterationNode.__new__(IterationNode)
-    node._node_data = IterationNodeData(
-        title="Parallel Iteration",
-        iterator_selector=["start", "items"],
-        output_selector=["iteration", "output"],
-        is_parallel=True,
-        parallel_nums=2,
-        error_handle_mode=ErrorHandleMode.TERMINATED,
-    )
-    node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)
-
-    def fake_execute_tracked_iteration_parallel(
-        *,
-        index: int,
-        item: object,
-        started_child_engines: dict[int, object],
-        started_child_engines_lock: object,
-    ):
-        _ = started_child_engines
-        _ = started_child_engines_lock
-        return (
-            0.1 + (index * 0.1),
-            [
-                NodeRunSucceededEvent(
-                    id=f"exec-{index}",
-                    node_id=f"llm-{index}",
-                    node_type=BuiltinNodeTypes.LLM,
-                    start_at=datetime.now(UTC).replace(tzinfo=None),
-                ),
-            ],
-            f"output-{item}",
-            LLMUsage.empty_usage(),
-        )
-
-    node._execute_tracked_iteration_parallel = fake_execute_tracked_iteration_parallel
-
-    outputs: list[object] = []
-    iter_run_map: dict[str, float] = {}
-    usage_accumulator = [LLMUsage.empty_usage()]
-
-    generator = node._execute_parallel_iterations(
-        iterator_list_value=["a", "b"],
-        outputs=outputs,
-        iter_run_map=iter_run_map,
-        usage_accumulator=usage_accumulator,
-    )
-
-    for _ in generator:
-        # Simulate a slow consumer replaying buffered events.
-        time.sleep(0.02)
-
-    assert outputs == ["output-a", "output-b"]
-    assert iter_run_map["0"] == pytest.approx(0.1)
-    assert iter_run_map["1"] == pytest.approx(0.2)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
index a6fca1bfb4..f8802138b5 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
@@ -3,6 +3,9 @@ import uuid
 from unittest.mock import Mock

 import pytest
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variables.segments import StringSegment

 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -16,9 +19,6 @@ from core.workflow.nodes.knowledge_index.protocols import (
     SummaryIndexServiceProtocol,
 )
 from core.workflow.system_variables import SystemVariableKey, build_system_variables
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variables.segments import StringSegment
 from tests.workflow_test_utils import build_test_graph_init_params
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
index 45e8ae7d20..ab64be59ad 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
@@ -3,6 +3,10 @@ import uuid
 from unittest.mock import Mock

 import pytest
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variables import StringSegment

 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.workflow.nodes.knowledge_retrieval.entities import (
@@ -17,10 +21,6 @@ from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
 from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source
 from core.workflow.system_variables import build_system_variables
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variables import StringSegment
 from tests.workflow_test_utils import build_test_graph_init_params
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
index eca34f05be..fdf1706765 100644
--- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
@@ -1,14 +1,14 @@
 from unittest.mock import MagicMock

 import pytest
-
-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
 from graphon.entities import GraphInitParams
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
 from graphon.nodes.list_operator.node import ListOperatorNode
 from graphon.runtime import GraphRuntimeState
 from graphon.variables import ArrayNumberSegment, ArrayStringSegment
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
+

 class TestListOperatorNode:
     """Comprehensive tests for ListOperatorNode."""
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
deleted file mode 100644
index 4f9ba0194a..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import uuid
-from typing import NamedTuple
-from unittest import mock
-from unittest.mock import MagicMock
-
-import httpx
-import pytest
-
-from graphon.file import FileTransferMethod, FileType
-from graphon.nodes.llm.file_saver import (
-    FileSaverImpl,
-    _extract_content_type_and_extension,
-    _get_extension,
-    _validate_extension_override,
-)
-from graphon.nodes.protocols import ToolFileManagerProtocol
-
-_PNG_DATA = b"\x89PNG\r\n\x1a\n"
-
-
-def _gen_id():
-    return str(uuid.uuid4())
-
-
-class TestFileSaverImpl:
-    def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch):
-        file_type = FileType.IMAGE
-        mime_type = "image/png"
-        mock_tool_file = MagicMock()
-        mock_tool_file.id = _gen_id()
-        mock_tool_file.name = f"{_gen_id()}.png"
-        mock_tool_file.file_key = "test-file-key"
-        mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManagerProtocol)
-        mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
-        file_reference = MagicMock()
-        file_reference_factory = MagicMock()
-        file_reference_factory.build_from_mapping.return_value = file_reference
-        http_client = MagicMock()
-
-        file_saver = FileSaverImpl(
-            tool_file_manager=mocked_tool_file_manager,
-            file_reference_factory=file_reference_factory,
-            http_client=http_client,
-        )
-
-        file = file_saver.save_binary_string(_PNG_DATA, mime_type, file_type)
-        assert file is file_reference
-
-        mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
-            file_binary=_PNG_DATA,
-            mimetype=mime_type,
-        )
-        file_reference_factory.build_from_mapping.assert_called_once_with(
-            mapping={
-                "type": file_type,
-                "transfer_method": FileTransferMethod.TOOL_FILE,
-                "filename": mock_tool_file.name,
-                "extension": ".png",
-                "mime_type": mime_type,
-                "size": len(_PNG_DATA),
-                "tool_file_id": mock_tool_file.id,
-                "related_id": mock_tool_file.id,
-                "storage_key": mock_tool_file.file_key,
-            }
-        )
-
-    def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
-        _TEST_URL = "https://example.com/image.png"
-        mock_request = httpx.Request("GET", _TEST_URL)
-        mock_response = httpx.Response(
-            status_code=401,
-            request=mock_request,
-        )
-        http_client = MagicMock()
-        http_client.get.return_value = mock_response
-
-        file_saver = FileSaverImpl(
-            tool_file_manager=MagicMock(),
-            file_reference_factory=MagicMock(),
-            http_client=http_client,
-        )
-
-        with pytest.raises(httpx.HTTPStatusError) as exc:
-            file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
-        http_client.get.assert_called_once_with(_TEST_URL)
-        assert exc.value.response.status_code == 401
-
-    def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
-        _TEST_URL = "https://example.com/image.png"
-        mime_type = "image/png"
-
-        mock_request = httpx.Request("GET", _TEST_URL)
-        mock_response = httpx.Response(
-            status_code=200,
-            content=b"test-data",
-            headers={"Content-Type": mime_type},
-            request=mock_request,
-        )
-        http_client = MagicMock()
-        http_client.get.return_value = mock_response
-
-        file_saver = FileSaverImpl(
-            tool_file_manager=MagicMock(),
-            file_reference_factory=MagicMock(),
-            http_client=http_client,
-        )
-        expected_file = MagicMock()
-        mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=expected_file)
-        monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
-
-        file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
-        mock_save_binary_string.assert_called_once_with(
-            mock_response.content,
-            mime_type,
-            FileType.IMAGE,
-            extension_override=".png",
-        )
-        assert file is expected_file
-
-
-def test_validate_extension_override():
-    class TestCase(NamedTuple):
-        extension_override: str | None
-        expected: str | None
-
-    cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
-
-    for valid_ext_override in [None, "", ".png", ".tar.gz"]:
-        assert valid_ext_override == _validate_extension_override(valid_ext_override)
-
-    for invalid_ext_override in ["png", "tar.gz"]:
-        with pytest.raises(ValueError) as exc:
-            _validate_extension_override(invalid_ext_override)
-
-
-class TestExtractContentTypeAndExtension:
-    def test_with_both_content_type_and_extension(self):
-        content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
-        assert content_type == "image/png"
-        assert extension == ".png"
-
-    def test_url_with_file_extension(self):
-        for content_type in [None, ""]:
-            content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
-            assert content_type == "image/png"
-            assert extension == ".png"
-
-    def test_response_with_content_type(self):
-        content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
-        assert content_type == "image/png"
-        assert extension == ".png"
-
-    def test_no_content_type_and_no_extension(self):
-        for content_type in [None, ""]:
-            content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
-            assert content_type == "application/octet-stream"
-            assert extension == ".bin"
-
-
-class TestGetExtension:
-    def test_with_extension_override(self):
-        mime_type = "image/png"
-        for override in [".jpg", ""]:
-            extension = _get_extension(mime_type, override)
-            assert extension == override
-
-    def test_without_extension_override(self):
-        mime_type = "image/png"
-        extension = _get_extension(mime_type)
-        assert extension == ".png"
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index dfc982f49c..c784f805c0 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -1,10 +1,7 @@
 from unittest import mock

 import pytest
-
-from core.model_manager import ModelInstance
-from graphon.file import FileTransferMethod, FileType
-from graphon.file.models import File
+from graphon.file import File, FileTransferMethod, FileType
 from graphon.model_runtime.entities import (
     ImagePromptMessageContent,
     PromptMessageRole,
@@ -36,6 +33,8 @@ from graphon.nodes.llm.exc import (
 )
 from graphon.runtime import VariablePool
 from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment

+from core.model_manager import ModelInstance
+

 def _build_model_schema(
     *,
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index a2fbc50392..a215e9d350 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -4,19 +4,6 @@ from collections.abc import Sequence
 from unittest import mock

 import pytest
-
-from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom
-from core.app.llm.model_access import (
-    DifyCredentialsProvider,
-    DifyModelFactory,
-    build_dify_model_access,
-    fetch_model_config,
-)
-from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
-from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
-from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
-from core.workflow.system_variables import default_system_variables
 from graphon.entities import GraphInitParams
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.model_runtime.entities.common_entities import I18nObject
@@ -79,6 +66,19 @@ from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
 from graphon.runtime import GraphRuntimeState, VariablePool
 from graphon.template_rendering import TemplateRenderError
 from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
+
+from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom
+from core.app.llm.model_access import (
+    DifyCredentialsProvider,
+    DifyModelFactory,
+    build_dify_model_access,
+    fetch_model_config,
+)
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
+from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.workflow.system_variables import default_system_variables
 from models.provider import ProviderType
 from tests.workflow_test_utils import build_test_graph_init_params
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
deleted file mode 100644
index af1cff4e81..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from collections.abc import Mapping, Sequence
-
-from pydantic import BaseModel, Field
-
-from graphon.file import File
-from graphon.model_runtime.entities.message_entities import PromptMessage
-from graphon.model_runtime.entities.model_entities import ModelFeature
-from graphon.nodes.llm.entities import LLMNodeChatModelMessage
-
-
-class LLMNodeTestScenario(BaseModel):
-    """Test scenario for LLM node testing."""
-
-    description: str = Field(..., description="Description of the test scenario")
-    sys_query: str = Field(..., description="User query input")
-    sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
-    vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
-    vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
-    features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
-    window_size: int = Field(..., description="Window size for memory")
-    prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
-    file_variables: Mapping[str, File | Sequence[File]] = Field(
-        default_factory=dict, description="List of file variables"
-    )
-    expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py
deleted file mode 100644
index ccf1077838..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from graphon.nodes.parameter_extractor.entities import ParameterConfig
-from graphon.variables.types import SegmentType
-
-
-class TestParameterConfig:
-    def test_select_type(self):
-        data = {
-            "name": "yes_or_no",
-            "type": "select",
-            "options": ["yes", "no"],
-            "description": "a simple select made of `yes` and `no`",
-            "required": True,
-        }
-
-        pc = ParameterConfig.model_validate(data)
-        assert pc.type == SegmentType.STRING
-        assert pc.options == data["options"]
-
-    def test_validate_bool_type(self):
-        data = {
-            "name": "boolean",
-            "type": "bool",
-            "description": "a simple boolean parameter",
-            "required": True,
-        }
-        pc = ParameterConfig.model_validate(data)
-        assert pc.type == SegmentType.BOOLEAN
diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
index 8f8ec49f14..1c362a0a03 100644
--- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
@@ -6,8 +6,6 @@ from dataclasses import dataclass
 from typing import Any

 import pytest
-
-from factories.variable_factory import build_segment_with_type
 from graphon.model_runtime.entities import LLMMode
 from graphon.nodes.llm import ModelConfig, VisionConfig
 from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
@@ -20,6 +18,8 @@ from graphon.nodes.parameter_extractor.exc import (
 from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from graphon.variables.types import SegmentType

+from factories.variable_factory import build_segment_with_type
+

 @dataclass
 class ValidTestCase:
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py
deleted file mode 100644
index 01878ed692..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py
+++ /dev/null
@@ -1,225 +0,0 @@
-import pytest
-from pydantic import ValidationError
-
-from graphon.enums import ErrorStrategy
-from graphon.nodes.template_transform.entities import TemplateTransformNodeData
-
-
-class TestTemplateTransformNodeData:
-    """Test suite for TemplateTransformNodeData entity."""
-
-    def test_valid_template_transform_node_data(self):
-        """Test creating valid TemplateTransformNodeData."""
-        data = {
-            "title": "Template Transform",
-            "desc": "Transform data using Jinja2 template",
-            "variables": [
-                {"variable": "name", "value_selector": ["sys", "user_name"]},
-                {"variable": "age", "value_selector": ["sys", "user_age"]},
-            ],
-            "template": "Hello {{ name }}, you are {{ age }} years old!",
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert node_data.title == "Template Transform"
-        assert node_data.desc == "Transform data using Jinja2 template"
-        assert len(node_data.variables) == 2
-        assert node_data.variables[0].variable == "name"
-        assert node_data.variables[0].value_selector == ["sys", "user_name"]
-        assert node_data.variables[1].variable == "age"
-        assert node_data.variables[1].value_selector == ["sys", "user_age"]
-        assert node_data.template == "Hello {{ name }}, you are {{ age }} years old!"
-
-    def test_template_transform_node_data_with_empty_variables(self):
-        """Test TemplateTransformNodeData with no variables."""
-        data = {
-            "title": "Static Template",
-            "variables": [],
-            "template": "This is a static template with no variables.",
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert node_data.title == "Static Template"
-        assert len(node_data.variables) == 0
-        assert node_data.template == "This is a static template with no variables."
-
-    def test_template_transform_node_data_with_complex_template(self):
-        """Test TemplateTransformNodeData with complex Jinja2 template."""
-        data = {
-            "title": "Complex Template",
-            "variables": [
-                {"variable": "items", "value_selector": ["sys", "item_list"]},
-                {"variable": "total", "value_selector": ["sys", "total_count"]},
-            ],
-            "template": (
-                "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}. Total: {{ total }}"
-            ),
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert node_data.title == "Complex Template"
-        assert len(node_data.variables) == 2
-        assert "{% for item in items %}" in node_data.template
-        assert "{{ total }}" in node_data.template
-
-    def test_template_transform_node_data_with_error_strategy(self):
-        """Test TemplateTransformNodeData with error handling strategy."""
-        data = {
-            "title": "Template with Error Handling",
-            "variables": [{"variable": "value", "value_selector": ["sys", "input"]}],
-            "template": "{{ value }}",
-            "error_strategy": "fail-branch",
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
-
-    def test_template_transform_node_data_with_retry_config(self):
-        """Test TemplateTransformNodeData with retry configuration."""
-        data = {
-            "title": "Template with Retry",
-            "variables": [{"variable": "data", "value_selector": ["sys", "data"]}],
-            "template": "{{ data }}",
-            "retry_config": {"enabled": True, "max_retries": 3, "retry_interval": 1000},
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert node_data.retry_config.enabled is True
-        assert node_data.retry_config.max_retries == 3
-        assert node_data.retry_config.retry_interval == 1000
-
-    def test_template_transform_node_data_missing_required_fields(self):
-        """Test that missing required fields raises ValidationError."""
-        data = {
-            "title": "Incomplete Template",
-            # Missing 'variables' and 'template'
-        }
-
-        with pytest.raises(ValidationError) as exc_info:
-            TemplateTransformNodeData.model_validate(data)
-
-        errors = exc_info.value.errors()
-        assert len(errors) >= 2
-        error_fields = {error["loc"][0] for error in errors}
-        assert "variables" in error_fields
-        assert "template" in error_fields
-
-    def test_template_transform_node_data_invalid_variable_selector(self):
-        """Test that invalid variable selector format raises ValidationError."""
-        data = {
-            "title": "Invalid Variable",
-            "variables": [
-                {"variable": "name", "value_selector": "invalid_format"}  # Should be list
-            ],
-            "template": "{{ name }}",
-        }
-
-        with pytest.raises(ValidationError):
-            TemplateTransformNodeData.model_validate(data)
-
-    def test_template_transform_node_data_with_default_value_dict(self):
-        """Test TemplateTransformNodeData with default value dictionary."""
-        data = {
-            "title": "Template with Defaults",
-            "variables": [
-                {"variable": "name", "value_selector": ["sys", "user_name"]},
-                {"variable": "greeting", "value_selector": ["sys", "greeting"]},
-            ],
-            "template": "{{ greeting }} {{ name }}!",
-            "default_value_dict": {"greeting": "Hello", "name": "Guest"},
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert node_data.default_value_dict == {"greeting": "Hello", "name": "Guest"}
-
-    def test_template_transform_node_data_with_nested_selectors(self):
-        """Test TemplateTransformNodeData with nested variable selectors."""
-        data = {
-            "title": "Nested Selectors",
-            "variables": [
-                {"variable": "user_info", "value_selector": ["sys", "user", "profile", "name"]},
-                {"variable": "settings", "value_selector": ["sys", "config", "app", "theme"]},
-            ],
-            "template": "User: {{ user_info }}, Theme: {{ settings }}",
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert len(node_data.variables) == 2
-        assert node_data.variables[0].value_selector == ["sys", "user", "profile", "name"]
-        assert node_data.variables[1].value_selector == ["sys", "config", "app", "theme"]
-
-    def test_template_transform_node_data_with_multiline_template(self):
-        """Test TemplateTransformNodeData with multiline template."""
-        data = {
-            "title": "Multiline Template",
-            "variables": [
-                {"variable": "title", "value_selector": ["sys", "title"]},
-                {"variable": "content", "value_selector": ["sys", "content"]},
-            ],
-            "template": """
-# {{ title }}
-
-{{ content }}
-
----
-Generated by Template Transform Node
-            """,
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert "# {{ title }}" in node_data.template
-        assert "{{ content }}" in node_data.template
-        assert "Generated by Template Transform Node" in node_data.template
-
-    def test_template_transform_node_data_serialization(self):
-        """Test that TemplateTransformNodeData can be serialized and deserialized."""
-        original_data = {
-            "title": "Serialization Test",
-            "desc": "Test serialization",
-            "variables": [{"variable": "test", "value_selector": ["sys", "test"]}],
-            "template": "{{ test }}",
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(original_data)
-        serialized = node_data.model_dump()
-        deserialized = TemplateTransformNodeData.model_validate(serialized)
-
-        assert deserialized.title == node_data.title
-        assert deserialized.desc == node_data.desc
-        assert len(deserialized.variables) == len(node_data.variables)
-        assert deserialized.template == node_data.template
-
-    def test_template_transform_node_data_with_special_characters(self):
-        """Test TemplateTransformNodeData with special characters in template."""
-        data = {
-            "title": "Special Characters",
-            "variables": [{"variable": "text", "value_selector": ["sys", "input"]}],
-            "template": "Special: {{ text }} | Symbols: @#$%^&*() | Unicode: 你好 🎉",
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert "@#$%^&*()" in node_data.template
-        assert "你好" in node_data.template
-        assert "🎉" in node_data.template
-
-    def test_template_transform_node_data_empty_template(self):
-        """Test TemplateTransformNodeData with empty template string."""
-        data = {
-            "title": "Empty Template",
-            "variables": [],
-            "template": "",
-        }
-
-        node_data = TemplateTransformNodeData.model_validate(data)
-
-        assert node_data.template == ""
-        assert len(node_data.variables) == 0
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
index bc44ececd8..d86e0efe02 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -1,8 +1,6 @@
 from unittest.mock import MagicMock

 import pytest
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus
 from graphon.graph import Graph
 from graphon.nodes.base.entities import VariableSelector
@@ -10,6 +8,8 @@ from graphon.nodes.template_transform.entities import TemplateTransformNodeData
 from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
 from graphon.runtime import GraphRuntimeState
 from graphon.template_rendering import TemplateRenderError
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from tests.workflow_test_utils import build_test_graph_init_params
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
index 636237e56e..bd22a8e318 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
@@ -1,14 +1,14 @@
 from unittest.mock import MagicMock

 import pytest
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from graphon.nodes.base.entities import VariableSelector
 from graphon.nodes.template_transform.template_transform_node import (
     DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH,
     TemplateTransformNode,
 )
 from graphon.runtime import GraphRuntimeState
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from tests.workflow_test_utils import build_test_graph_init_params

 from .template_transform_node_spec import TestTemplateTransformNode  # noqa: F401
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
index 0522dd9d14..e11ebf6eb8 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
@@ -1,16 +1,16 @@
 from collections.abc import Mapping

 import pytest
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.workflow.node_runtime import resolve_dify_run_context
-from core.workflow.system_variables import build_system_variables
 from graphon.entities import GraphInitParams
 from graphon.entities.base_node_data import BaseNodeData
 from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
 from graphon.enums import BuiltinNodeTypes
 from graphon.nodes.base.node import Node
 from graphon.runtime import GraphRuntimeState, VariablePool
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.workflow.node_runtime import resolve_dify_run_context
+from core.workflow.system_variables import build_system_variables
 from tests.workflow_test_utils import
build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 87ec2d5bce..555ff0c945 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,8 +4,6 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod @@ -21,6 +19,8 @@ from graphon.nodes.document_extractor.node import ( from graphon.variables import ArrayFileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 782750e02e..1b14f0ab13 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,11 +3,6 @@ import uuid from unittest.mock import MagicMock, Mock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph @@ -16,6 +11,11 @@ from graphon.nodes.if_else.if_else_node import IfElseNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition from graphon.variables import ArrayFileSegment + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index b217e4e8e7..d28c3e01e5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,8 +1,6 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.nodes.list_operator.entities import ( @@ -18,6 +16,8 @@ from graphon.nodes.list_operator.exc import InvalidKeyError from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from graphon.variables import ArrayFileSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom + @pytest.fixture def list_operator_node(): diff --git 
a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py deleted file mode 100644 index d613ba154a..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py +++ /dev/null @@ -1,150 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph_events import GraphRunAbortedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.node_events import LoopFailedEvent, LoopStartedEvent, StreamCompletedEvent -from graphon.nodes.loop.entities import LoopNodeData -from graphon.nodes.loop.loop_node import LoopNode -from tests.workflow_test_utils import build_test_variable_pool - - -def _usage_with_tokens(total_tokens: int) -> LLMUsage: - usage = LLMUsage.empty_usage() - usage.total_tokens = total_tokens - return usage - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "loop_id": "loop-node", - }, - } - - LoopNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "loop-node", - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="loop-node", - node_data=LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - ), - ) - - assert seen_configs == [child_node_config] - - -def test_run_single_loop_raises_on_child_abort_event() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - - graph_engine = SimpleNamespace( - run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), - ) - - with pytest.raises(RuntimeError, match="quota exceeded"): - list(node._run_single_loop(graph_engine=graph_engine, current_index=0)) - - -def test_loop_run_fails_on_child_abort_and_stops_subsequent_rounds() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=2, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - node.graph_config = {"nodes": [], "edges": []} - node.graph_runtime_state = SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=LLMUsage.empty_usage(), - ) - - aborting_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=LLMUsage.empty_usage()), - ) - create_graph_engine = MagicMock(return_value=aborting_engine) - node._create_graph_engine = create_graph_engine - node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[0], LoopStartedEvent) - assert isinstance(events[1], 
LoopFailedEvent) - assert events[1].error == "quota exceeded" - assert isinstance(events[2], StreamCompletedEvent) - assert events[2].node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert events[2].node_run_result.error == "quota exceeded" - create_graph_engine.assert_called_once() - - -def test_loop_run_merges_child_usage_before_failing_on_child_abort() -> None: - node = LoopNode.__new__(LoopNode) - node._node_id = "loop-node" - node._node_data = LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - start_node_id="child-start", - ) - node.graph_config = {"nodes": [], "edges": []} - node.graph_runtime_state = SimpleNamespace( - variable_pool=build_test_variable_pool(), - llm_usage=LLMUsage.empty_usage(), - ) - - aborting_engine = SimpleNamespace( - graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=_usage_with_tokens(7)), - ) - node._create_graph_engine = MagicMock(return_value=aborting_engine) - node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) - - events = list(node._run()) - - assert isinstance(events[-1], StreamCompletedEvent) - assert events[-1].node_run_result.llm_usage.total_tokens == 7 - assert node.graph_runtime_state.llm_usage.total_tokens == 7 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py deleted file mode 100644 index efbf786a55..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ /dev/null @@ -1,126 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -from graphon.model_runtime.entities import ImagePromptMessageContent -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.protocols import HttpClientProtocol -from graphon.nodes.question_classifier import ( - QuestionClassifierNode, - QuestionClassifierNodeData, -) -from graphon.template_rendering import Jinja2TemplateRenderer -from tests.workflow_test_utils import build_test_graph_init_params - - -def test_init_question_classifier_node_data(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == True - assert 
node_data.vision.configs.variable_selector == ["image"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW - - -def test_init_question_classifier_node_data_without_vision_config(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == False - assert node_data.vision.configs.variable_selector == ["sys", "files"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH - - -def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch): - node_data = QuestionClassifierNodeData.model_validate( - { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - } - ) - template_renderer = MagicMock(spec=Jinja2TemplateRenderer) - node = QuestionClassifierNode( - id="node-id", - config={"id": "node-id", "data": node_data.model_dump(mode="json")}, - graph_init_params=build_test_graph_init_params( - workflow_id="workflow-id", - graph_config={}, - tenant_id="tenant-id", - app_id="app-id", - user_id="user-id", - ), - graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()), - credentials_provider=MagicMock(spec=CredentialsProvider), - model_factory=MagicMock(spec=ModelFactory), - model_instance=MagicMock(), - http_client=MagicMock(spec=HttpClientProtocol), - llm_file_saver=MagicMock(), - template_renderer=template_renderer, - ) - fetch_prompt_messages = MagicMock(return_value=([], None)) - monkeypatch.setattr( - "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", - fetch_prompt_messages, - ) - monkeypatch.setattr( - "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", - MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])), - ) - - node._calculate_rest_token( - node_data=node_data, - query="hello", - model_instance=MagicMock(stop=(), parameters={}), - context="", - ) - - assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 543f9878de..833c303052 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,16 +2,16 @@ import json import time import pytest -from pydantic import ValidationError as PydanticValidationError - -from core.workflow.system_variables import build_system_variables -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState from graphon.variables import build_segment, segment_to_variable from graphon.variables.input_entities import VariableEntity, VariableEntityType from graphon.variables.variables import Variable +from pydantic import ValidationError as PydanticValidationError + +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index c806181340..1587014802 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -8,14 +8,14 @@ from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock import pytest - -from core.workflow.system_variables import build_system_variables from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment + +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index 438af211f3..c4dfc5a179 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -6,6 +6,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError @@ -17,11 +22,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.system_variables import build_system_variables -from graphon.model_runtime.entities.llm_entities import LLMUsage -from 
graphon.nodes.tool.entities import ToolNodeData, ToolProviderType -from graphon.nodes.tool.exc import ToolRuntimeInvocationError -from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage -from graphon.runtime import VariablePool from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index c8ddc53284..952e798430 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,12 +1,13 @@ from collections.abc import Mapping -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode -from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.runtime import GraphRuntimeState + +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py deleted file mode 100644 index fabc8df73e..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ /dev/null @@ -1,312 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent -from graphon.nodes.variable_assigner.common import helpers as common_helpers -from graphon.nodes.variable_assigner.v1 import VariableAssignerNode -from graphon.nodes.variable_assigner.v1.node_data import WriteMode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import ArrayStringVariable, StringVariable - -DEFAULT_NODE_ID = "node_id" - - -def _build_variable_pool( - *, - conversation_id: str, - conversation_variables: list[StringVariable | ArrayStringVariable], -) -> VariablePool: - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id=conversation_id), - conversation_variables=conversation_variables, - ), - ) - return variable_pool - - -def test_overwrite_string_variable(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - 
"type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "over-write", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = StringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value="the first value", - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - conversation_id = str(uuid.uuid4()) - - # construct variable pool - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == input_variable.value - assert updated_event.variable.value == "the second value" - assert tuple(updated_event.variable.selector) == ("conversation", conversation_variable.name) - - -def test_append_variable_to_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "append", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - 
conversation_id = str(uuid.uuid4()) - - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == ["the first value", "the second value"] - assert updated_event.variable.value == ["the first value", "the second value"] - - -def test_clear_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "clear", - "input_variable_selector": [], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - conversation_id = str(uuid.uuid4()) - variable_pool = _build_variable_pool( - conversation_id=conversation_id, - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR, - "input_variable_selector": [], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) - succeeded_event = next(event for event in events if isinstance(event, 
NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == [] - assert updated_event.variable.value == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py deleted file mode 100644 index 9ac8bbe9c2..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ /dev/null @@ -1,22 +0,0 @@ -from graphon.nodes.variable_assigner.v2.enums import Operation -from graphon.nodes.variable_assigner.v2.helpers import is_input_value_valid -from graphon.variables import SegmentType - - -def test_is_input_value_valid_overwrite_array_string(): - # Valid cases - assert is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"] - ) - assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[]) - - # Invalid cases - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array" - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3] - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"] - ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py deleted file mode 100644 index 53346c4a90..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ /dev/null @@ -1,430 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool -from graphon.entities import GraphInitParams -from graphon.graph import Graph -from graphon.graph_events import NodeRunVariableUpdatedEvent -from graphon.nodes.variable_assigner.v2 import VariableAssignerNode -from graphon.nodes.variable_assigner.v2.enums import InputType, Operation -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import ArrayStringVariable - -DEFAULT_NODE_ID = "node_id" - - -def _build_variable_pool(*, conversation_variables: list[ArrayStringVariable]) -> VariablePool: - variable_pool = VariablePool() - add_variables_to_pool( - variable_pool, - build_bootstrap_variables( - system_variables=build_system_variables(conversation_id="conversation_id"), - conversation_variables=conversation_variables, - ), - ) - return variable_pool - - -def test_handle_item_directly(): - """Test the _handle_item method directly for 
remove operations.""" - # Create variables - variable1 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable1", - value=["first", "second", "third"], - ) - - variable2 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable2", - value=["first", "second", "third"], - ) - - # Create a mock class with just the _handle_item method - class MockNode: - def _handle_item(self, *, variable, operation, value): - match operation: - case Operation.REMOVE_FIRST: - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - if not variable.value: - return variable.value - return variable.value[:-1] - - node = MockNode() - - # Test remove-first - result1 = node._handle_item( - variable=variable1, - operation=Operation.REMOVE_FIRST, - value=None, - ) - - # Test remove-last - result2 = node._handle_item( - variable=variable2, - operation=Operation.REMOVE_LAST, - value=None, - ) - - # Check the results - assert result1 == ["second", "third"] - assert result2 == ["first", "second"] - - -def test_remove_first_from_array(): - """Test removing the first element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - # Run the node - result = list(node.run()) - - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == ["second", "third"] - - -def test_remove_last_from_array(): - """Test removing the last element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - 
DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == ["first", "second"] - - -def test_remove_first_from_empty_array(): - """Test removing the first element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == [] - - -def test_remove_last_from_empty_array(): - """Test removing the last element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": 
"start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - result = list(node.run()) - updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) - assert updated_event.variable.value == [] - - -def test_node_factory_creates_variable_assigner_node(): - graph_config = { - "edges": [], - "nodes": [ - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - variable_pool = _build_variable_pool(conversation_variables=[]) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - - node = node_factory.create_node(graph_config["nodes"][0]) - - assert isinstance(node, VariableAssignerNode) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index 617554ee17..f1132af02b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,4 +1,5 @@ import pytest +from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -6,7 +7,6 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) -from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py 
index 6fbd26131d..cccd3fb676 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,6 +8,10 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -17,10 +21,6 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 9f954b2090..34c66a4f9f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,6 +1,11 @@ from unittest.mock import patch import pytest +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import FileVariable, StringVariable from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE @@ -13,12 +18,6 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables -from graphon.entities.graph_init_params import GraphInitParams -from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool -from graphon.variables import FileVariable, StringVariable from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py deleted file mode 100644 index 453e0a8502..0000000000 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Tests for workflow pause related enums and constants.""" - -from graphon.enums import ( - WorkflowExecutionStatus, -) - - -class TestWorkflowExecutionStatus: - """Test WorkflowExecutionStatus enum.""" - - def test_is_ended_method(self): - """Test is_ended method for different statuses.""" - # Test ended statuses - ended_statuses = [ - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - WorkflowExecutionStatus.STOPPED, - ] - - for status in ended_statuses: - assert status.is_ended(), f"{status} should be considered ended" 
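# A compact restatement of the invariant the two loops in this deleted test
# pin down (assuming the enum has no members beyond the seven listed here):
# is_ended() is True exactly for the terminal statuses and False for the
# in-flight ones, so the two sets partition WorkflowExecutionStatus.
# assert {s for s in WorkflowExecutionStatus if s.is_ended()} == {
#     WorkflowExecutionStatus.SUCCEEDED,
#     WorkflowExecutionStatus.FAILED,
#     WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
#     WorkflowExecutionStatus.STOPPED,
# }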
- - # Test non-ended statuses - non_ended_statuses = [ - WorkflowExecutionStatus.SCHEDULED, - WorkflowExecutionStatus.RUNNING, - WorkflowExecutionStatus.PAUSED, - ] - - for status in non_ended_statuses: - assert not status.is_ended(), f"{status} should not be considered ended" - - def test_ended_values(self): - """Test ended_values returns the expected status values.""" - assert set(WorkflowExecutionStatus.ended_values()) == { - WorkflowExecutionStatus.SUCCEEDED.value, - WorkflowExecutionStatus.FAILED.value, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, - WorkflowExecutionStatus.STOPPED.value, - } diff --git a/api/tests/unit_tests/core/workflow/test_human_input_compat.py b/api/tests/unit_tests/core/workflow/test_human_input_compat.py index 0623800b30..cd41c43e4a 100644 --- a/api/tests/unit_tests/core/workflow/test_human_input_compat.py +++ b/api/tests/unit_tests/core/workflow/test_human_input_compat.py @@ -1,5 +1,6 @@ from types import SimpleNamespace +from graphon.enums import BuiltinNodeTypes from pydantic import BaseModel from core.workflow.human_input_compat import ( @@ -15,7 +16,6 @@ from core.workflow.human_input_compat import ( normalize_node_data_for_graph, parse_human_input_delivery_methods, ) -from graphon.enums import BuiltinNodeTypes def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 1db848a010..bc0b339fec 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -2,15 +2,15 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch, sentinel import pytest +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.code.entities import CodeLanguage +from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom from core.workflow import node_factory from core.workflow import template_rendering as workflow_template_rendering from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.code.entities import CodeLanguage -from graphon.variables.segments import StringSegment def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None: diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py index 71a2afb28a..4f9c1dad59 100644 --- a/api/tests/unit_tests/core/workflow/test_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -2,6 +2,10 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, sentinel import pytest +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.nodes.human_input.entities import HumanInputNodeData from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom from core.llm_generator.output_parser.errors import OutputParserError @@ -26,10 +30,6 @@ from core.workflow.node_runtime import ( build_dify_llm_file_saver, 
resolve_dify_run_context, ) -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.nodes.human_input.entities import HumanInputNodeData from tests.workflow_test_utils import build_test_run_context diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 72a0557b7c..05ea3dc311 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -1,14 +1,14 @@ from types import SimpleNamespace +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes import BuiltinNodeTypes + from core.workflow.system_variables import ( build_system_variables, default_system_variables, get_node_creation_preload_selectors, system_variables_to_mapping, ) -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.nodes import BuiltinNodeTypes def test_build_system_variables_normalizes_workflow_execution_id(): diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index dddd6eb00c..e7b2b2914a 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -2,15 +2,6 @@ import uuid from collections import defaultdict import pytest - -from core.workflow.system_variables import build_system_variables, system_variables_to_mapping -from core.workflow.variable_pool_initializer import add_variables_to_pool -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from factories.variable_factory import build_segment, segment_to_variable from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables import FileSegment, StringSegment @@ -36,6 +27,15 @@ from graphon.variables.variables import ( Variable, ) +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from factories.variable_factory import build_segment, segment_to_variable + @pytest.fixture def pool(): diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 4ae6ed1659..d8361d06c4 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,6 +1,12 @@ from types import SimpleNamespace import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import VariablePool +from graphon.variables.variables import StringVariable from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage @@ -10,13 +16,6 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, ) from core.workflow.workflow_entry import 
WorkflowEntry -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.file.enums import FileType -from graphon.file.models import File, FileTransferMethod -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import VariablePool -from graphon.variables.variables import StringVariable @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 456ab5da41..879c0bb721 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -4,18 +4,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch, sentinel import pytest - -from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.model_manager import ModelInstance -from core.workflow import workflow_entry -from core.workflow.system_variables import default_system_variables from graphon.entities.base_node_data import BaseNodeData from graphon.entities.graph_config import NodeConfigDictAdapter from graphon.enums import NodeType, WorkflowNodeExecutionStatus from graphon.errors import WorkflowNodeRunFailedError -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File +from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph from graphon.graph_events import GraphRunFailedEvent from graphon.model_runtime.entities.llm_entities import LLMUsage @@ -24,6 +17,12 @@ from graphon.nodes import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import ChildGraphNotFoundError, VariablePool from graphon.variables.variables import StringVariable + +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.model_manager import ModelInstance +from core.workflow import workflow_entry +from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index b3ecfe4bc9..4b2f98aeff 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,10 +2,11 @@ from unittest.mock import MagicMock, patch +from graphon.graph_engine.command_channels import RedisChannel +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.workflow_entry import WorkflowEntry -from graphon.graph_engine.command_channels.redis_channel import RedisChannel -from graphon.runtime import GraphRuntimeState, VariablePool class TestWorkflowEntryRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/utils/test_condition.py b/api/tests/unit_tests/core/workflow/utils/test_condition.py deleted file mode 100644 index f4c86aa77a..0000000000 --- a/api/tests/unit_tests/core/workflow/utils/test_condition.py +++ /dev/null @@ -1,52 +0,0 @@ -from graphon.runtime import VariablePool -from graphon.utils.condition.entities import Condition -from 
graphon.utils.condition.processor import ConditionProcessor - - -def test_number_formatting(): - condition_processor = ConditionProcessor() - variable_pool = VariablePool() - variable_pool.add(["test_node_id", "zone"], 0) - variable_pool.add(["test_node_id", "one"], 1) - variable_pool.add(["test_node_id", "one_one"], 1.1) - # 0 <= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "zone"], comparison_operator="≤", value="0.95")], - operator="or", - ).final_result - == True - ) - - # 1 >= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "one"], comparison_operator="≥", value="0.95")], - operator="or", - ).final_result - == True - ) - - # 1.1 >= 0.95 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[ - Condition(variable_selector=["test_node_id", "one_one"], comparison_operator="≥", value="0.95") - ], - operator="or", - ).final_result - == True - ) - - # 1.1 > 0 - assert ( - condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=[Condition(variable_selector=["test_node_id", "one_one"], comparison_operator=">", value="0")], - operator="or", - ).final_result - == True - ) diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py deleted file mode 100644 index 009c860f16..0000000000 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ /dev/null @@ -1,48 +0,0 @@ -import dataclasses - -from graphon.nodes.base import variable_template_parser -from graphon.nodes.base.entities import VariableSelector - - -def test_extract_selectors_from_template(): - template = ( - "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." 
- ) - selectors = variable_template_parser.extract_selectors_from_template(template) - assert selectors == [ - VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]), - VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), - VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), - ] - - -def test_invalid_references(): - @dataclasses.dataclass - class TestCase: - name: str - template: str - - cases = [ - TestCase( - name="lack of closing brace", - template="Hello, {{#sys.user_id#", - ), - TestCase( - name="lack of opening brace", - template="Hello, #sys.user_id#}}", - ), - TestCase( - name="lack selector name", - template="Hello, {{#sys#}}", - ), - TestCase( - name="empty node name part", - template="Hello, {{#.user_id#}}", - ), - ] - for idx, c in enumerate(cases, 1): - fail_msg = f"Test case {c.name} failed, index={idx}" - selectors = variable_template_parser.extract_selectors_from_template(c.template) - assert selectors == [], fail_msg - parser = variable_template_parser.VariableTemplateParser(c.template) - assert parser.extract_variable_selectors() == [], fail_msg diff --git a/api/graphon/model_runtime/callbacks/__init__.py b/api/tests/unit_tests/enterprise/telemetry/__init__.py similarity index 100% rename from api/graphon/model_runtime/callbacks/__init__.py rename to api/tests/unit_tests/enterprise/telemetry/__init__.py diff --git a/api/tests/unit_tests/enterprise/telemetry/test_contracts.py b/api/tests/unit_tests/enterprise/telemetry/test_contracts.py new file mode 100644 index 0000000000..7453525bfc --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_contracts.py @@ -0,0 +1,230 @@ +"""Unit tests for telemetry gateway contracts.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from core.telemetry.gateway import CASE_ROUTING +from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope + + +class TestTelemetryCase: + """Tests for TelemetryCase enum.""" + + def test_all_cases_defined(self) -> None: + """Verify all 14 telemetry cases are defined.""" + expected_cases = { + "WORKFLOW_RUN", + "NODE_EXECUTION", + "DRAFT_NODE_EXECUTION", + "MESSAGE_RUN", + "TOOL_EXECUTION", + "MODERATION_CHECK", + "SUGGESTED_QUESTION", + "DATASET_RETRIEVAL", + "GENERATE_NAME", + "PROMPT_GENERATION", + "APP_CREATED", + "APP_UPDATED", + "APP_DELETED", + "FEEDBACK_CREATED", + } + actual_cases = {case.name for case in TelemetryCase} + assert actual_cases == expected_cases + + def test_case_values(self) -> None: + """Verify case enum values are correct.""" + assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run" + assert TelemetryCase.NODE_EXECUTION.value == "node_execution" + assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution" + assert TelemetryCase.MESSAGE_RUN.value == "message_run" + assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution" + assert TelemetryCase.MODERATION_CHECK.value == "moderation_check" + assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question" + assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval" + assert TelemetryCase.GENERATE_NAME.value == "generate_name" + assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation" + assert TelemetryCase.APP_CREATED.value == "app_created" + assert TelemetryCase.APP_UPDATED.value == "app_updated" + assert TelemetryCase.APP_DELETED.value == "app_deleted" + assert 
TelemetryCase.FEEDBACK_CREATED.value == "feedback_created" + + +class TestCaseRoute: + """Tests for CaseRoute model.""" + + def test_valid_trace_route(self) -> None: + """Verify valid trace route creation.""" + route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True) + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is True + + def test_valid_metric_log_route(self) -> None: + """Verify valid metric_log route creation.""" + route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False) + assert route.signal_type == SignalType.METRIC_LOG + assert route.ce_eligible is False + + def test_invalid_signal_type(self) -> None: + """Verify invalid signal_type is rejected.""" + with pytest.raises(ValidationError): + CaseRoute(signal_type="invalid", ce_eligible=True) + + +class TestTelemetryEnvelope: + """Tests for TelemetryEnvelope model.""" + + def test_valid_envelope_minimal(self) -> None: + """Verify valid minimal envelope creation.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + ) + assert envelope.case == TelemetryCase.WORKFLOW_RUN + assert envelope.tenant_id == "tenant-123" + assert envelope.event_id == "event-456" + assert envelope.payload == {"key": "value"} + assert envelope.metadata is None + + def test_valid_envelope_full(self) -> None: + """Verify valid envelope with all fields.""" + metadata = {"payload_ref": "telemetry/tenant-789/event-012.json"} + envelope = TelemetryEnvelope( + case=TelemetryCase.MESSAGE_RUN, + tenant_id="tenant-789", + event_id="event-012", + payload={"message": "hello"}, + metadata=metadata, + ) + assert envelope.case == TelemetryCase.MESSAGE_RUN + assert envelope.tenant_id == "tenant-789" + assert envelope.event_id == "event-012" + assert envelope.payload == {"message": "hello"} + assert envelope.metadata == metadata + + def test_missing_required_case(self) -> None: + """Verify missing case field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + ) + + def test_missing_required_tenant_id(self) -> None: + """Verify missing tenant_id field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + event_id="event-456", + payload={"key": "value"}, + ) + + def test_missing_required_event_id(self) -> None: + """Verify missing event_id field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + payload={"key": "value"}, + ) + + def test_missing_required_payload(self) -> None: + """Verify missing payload field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + ) + + def test_metadata_none(self) -> None: + """Verify metadata can be None.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + metadata=None, + ) + assert envelope.metadata is None + + +class TestCaseRouting: + """Tests for CASE_ROUTING table.""" + + def test_all_cases_routed(self) -> None: + """Verify all 14 cases have routing entries.""" + assert len(CASE_ROUTING) == 14 + for case in TelemetryCase: + assert case in CASE_ROUTING + + def test_trace_ce_eligible_cases(self) -> None: + """Verify trace cases with CE 
eligibility.""" + ce_eligible_trace_cases = { + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + } + for case in ce_eligible_trace_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is True + + def test_trace_enterprise_only_cases(self) -> None: + """Verify trace cases that are enterprise-only.""" + enterprise_only_trace_cases = { + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + } + for case in enterprise_only_trace_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is False + + def test_metric_log_cases(self) -> None: + """Verify metric/log-only cases.""" + metric_log_cases = { + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + } + for case in metric_log_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.METRIC_LOG + assert route.ce_eligible is False + + def test_routing_table_completeness(self) -> None: + """Verify routing table covers all cases with correct types.""" + trace_cases = { + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + } + metric_log_cases = { + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + } + + all_cases = trace_cases | metric_log_cases + assert len(all_cases) == 14 + assert all_cases == set(TelemetryCase) + + for case in trace_cases: + assert CASE_ROUTING[case].signal_type == SignalType.TRACE + + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG diff --git a/api/tests/unit_tests/enterprise/telemetry/test_draft_trace.py b/api/tests/unit_tests/enterprise/telemetry/test_draft_trace.py new file mode 100644 index 0000000000..c8c8de8595 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_draft_trace.py @@ -0,0 +1,519 @@ +"""Unit tests for enterprise/telemetry/draft_trace.py.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch + +from graphon.enums import WorkflowNodeExecutionMetadataKey + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_execution(**overrides) -> MagicMock: + """Return a minimal WorkflowNodeExecutionModel mock.""" + execution = MagicMock() + execution.tenant_id = overrides.get("tenant_id", "tenant-1") + execution.app_id = overrides.get("app_id", "app-1") + execution.workflow_id = overrides.get("workflow_id", "wf-1") + execution.id = overrides.get("id", "exec-1") + execution.node_id = overrides.get("node_id", "node-1") + execution.node_type = overrides.get("node_type", "llm") + execution.title = overrides.get("title", "My LLM Node") + execution.status = overrides.get("status", "succeeded") + execution.error = overrides.get("error") + execution.elapsed_time = overrides.get("elapsed_time", 1.5) + execution.index = overrides.get("index", 1) + execution.predecessor_node_id = overrides.get("predecessor_node_id") + execution.created_at = overrides.get("created_at", 
datetime(2024, 1, 1, tzinfo=UTC)) + execution.finished_at = overrides.get("finished_at", datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)) + execution.workflow_run_id = overrides.get("workflow_run_id", "run-1") + execution.inputs_dict = overrides.get("inputs_dict", {"prompt": "hello"}) + execution.outputs_dict = overrides.get("outputs_dict", {"answer": "world"}) + execution.process_data_dict = overrides.get("process_data_dict", {}) + execution.execution_metadata_dict = overrides.get("execution_metadata_dict", {}) + return execution + + +# --------------------------------------------------------------------------- +# _build_node_execution_data +# --------------------------------------------------------------------------- + + +class TestBuildNodeExecutionData: + def test_basic_fields_populated(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution() + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id="run-override", + ) + + assert result["workflow_id"] == "wf-1" + assert result["tenant_id"] == "tenant-1" + assert result["app_id"] == "app-1" + assert result["node_execution_id"] == "exec-1" + assert result["node_id"] == "node-1" + assert result["node_type"] == "llm" + assert result["title"] == "My LLM Node" + assert result["status"] == "succeeded" + assert result["error"] is None + assert result["elapsed_time"] == 1.5 + assert result["index"] == 1 + + def test_workflow_execution_id_prefers_parameter(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(workflow_run_id="run-from-model") + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id="explicit-run", + ) + assert result["workflow_execution_id"] == "explicit-run" + + def test_workflow_execution_id_falls_back_to_run_id(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(workflow_run_id="run-from-model") + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id=None, + ) + assert result["workflow_execution_id"] == "run-from-model" + + def test_workflow_execution_id_falls_back_to_execution_id(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(workflow_run_id=None, id="exec-fallback") + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id=None, + ) + assert result["workflow_execution_id"] == "exec-fallback" + + def test_outputs_param_overrides_execution_outputs(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(outputs_dict={"from_model": True}) + result = _build_node_execution_data( + execution=execution, + outputs={"from_param": True}, + workflow_execution_id=None, + ) + assert result["node_outputs"] == {"from_param": True} + + def test_outputs_none_uses_execution_outputs_dict(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(outputs_dict={"from_model": True}) + result = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id=None, + ) + assert result["node_outputs"] == {"from_model": True} + + def test_metadata_token_fields_default_to_zero(self) -> None: + from enterprise.telemetry.draft_trace import 
_build_node_execution_data + + execution = _make_execution(execution_metadata_dict={}) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["total_tokens"] == 0 + assert result["total_price"] == 0.0 + assert result["currency"] is None + + def test_metadata_token_fields_populated_from_metadata(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + metadata = { + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 200, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.05, + WorkflowNodeExecutionMetadataKey.CURRENCY: "USD", + } + execution = _make_execution(execution_metadata_dict=metadata) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["total_tokens"] == 200 + assert result["total_price"] == 0.05 + assert result["currency"] == "USD" + + def test_tool_name_extracted_from_tool_info_dict(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + metadata = { + WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"tool_name": "web_search"}, + } + execution = _make_execution(execution_metadata_dict=metadata) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["tool_name"] == "web_search" + + def test_tool_name_is_none_when_tool_info_not_dict(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: "not-a-dict"} + execution = _make_execution(execution_metadata_dict=metadata) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["tool_name"] is None + + def test_tool_name_is_none_when_tool_info_absent(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(execution_metadata_dict={}) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["tool_name"] is None + + def test_iteration_and_loop_fields(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + metadata = { + WorkflowNodeExecutionMetadataKey.ITERATION_ID: "iter-1", + WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: 3, + WorkflowNodeExecutionMetadataKey.LOOP_ID: "loop-1", + WorkflowNodeExecutionMetadataKey.LOOP_INDEX: 2, + WorkflowNodeExecutionMetadataKey.PARALLEL_ID: "par-1", + } + execution = _make_execution(execution_metadata_dict=metadata) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["iteration_id"] == "iter-1" + assert result["iteration_index"] == 3 + assert result["loop_id"] == "loop-1" + assert result["loop_index"] == 2 + assert result["parallel_id"] == "par-1" + + def test_node_inputs_and_process_data_included(self) -> None: + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution( + inputs_dict={"q": "test"}, + process_data_dict={"step": 1}, + ) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["node_inputs"] == {"q": "test"} + assert result["process_data"] == {"step": 1} + + +# --------------------------------------------------------------------------- +# enqueue_draft_node_execution_trace +# --------------------------------------------------------------------------- 
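+# These tests patch telemetry_emit where draft_trace looks it up, so no real
+# exporter or queue is required; each test simply inspects the TelemetryEvent
+# object handed to the mocked emit call.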
+ + +class TestEnqueueDraftNodeExecutionTrace: + @patch("enterprise.telemetry.draft_trace.telemetry_emit") + def test_emits_telemetry_event(self, mock_emit: MagicMock) -> None: + from core.telemetry import TelemetryEvent, TraceTaskName + from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + execution = _make_execution() + enqueue_draft_node_execution_trace( + execution=execution, + outputs={"result": "ok"}, + workflow_execution_id="run-x", + user_id="user-1", + ) + + mock_emit.assert_called_once() + event: TelemetryEvent = mock_emit.call_args[0][0] + assert event.name == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert event.context.tenant_id == "tenant-1" + assert event.context.user_id == "user-1" + assert event.context.app_id == "app-1" + + @patch("enterprise.telemetry.draft_trace.telemetry_emit") + def test_payload_contains_node_execution_data(self, mock_emit: MagicMock) -> None: + from core.telemetry import TelemetryEvent + from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + execution = _make_execution() + enqueue_draft_node_execution_trace( + execution=execution, + outputs=None, + workflow_execution_id=None, + user_id="user-2", + ) + + event: TelemetryEvent = mock_emit.call_args[0][0] + node_data = event.payload["node_execution_data"] + assert node_data["workflow_id"] == "wf-1" + assert node_data["node_type"] == "llm" + assert node_data["status"] == "succeeded" + + @patch("enterprise.telemetry.draft_trace.telemetry_emit") + def test_outputs_forwarded_to_build(self, mock_emit: MagicMock) -> None: + from core.telemetry import TelemetryEvent + from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + execution = _make_execution(outputs_dict={"default": True}) + enqueue_draft_node_execution_trace( + execution=execution, + outputs={"explicit": True}, + workflow_execution_id=None, + user_id="user-3", + ) + + event: TelemetryEvent = mock_emit.call_args[0][0] + assert event.payload["node_execution_data"]["node_outputs"] == {"explicit": True} + + @patch("enterprise.telemetry.draft_trace.telemetry_emit") + def test_none_outputs_uses_execution_outputs(self, mock_emit: MagicMock) -> None: + from core.telemetry import TelemetryEvent + from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + execution = _make_execution(outputs_dict={"from_model": "yes"}) + enqueue_draft_node_execution_trace( + execution=execution, + outputs=None, + workflow_execution_id=None, + user_id="user-4", + ) + + event: TelemetryEvent = mock_emit.call_args[0][0] + assert event.payload["node_execution_data"]["node_outputs"] == {"from_model": "yes"} + + +# --------------------------------------------------------------------------- +# End-to-end token/model data flow: _build_node_execution_data → +# ops_trace_manager.draft_node_execution_trace → DraftNodeExecutionTrace +# --------------------------------------------------------------------------- + + +def _make_llm_execution() -> MagicMock: + """Return a WorkflowNodeExecutionModel mock that mimics a real LLM node. 
+ + The field values match what graphon/nodes/llm/node.py produces: + - process_data_dict contains model_provider, model_name, and usage + - outputs_dict contains usage with prompt/completion breakdown + - execution_metadata_dict contains total_tokens/total_price/currency + """ + return _make_execution( + tenant_id="tenant-flow", + app_id="app-flow", + workflow_id="wf-flow", + id="exec-flow", + node_id="node-llm", + node_type="llm", + title="GPT-4o Node", + status="succeeded", + elapsed_time=2.3, + workflow_run_id=None, + process_data_dict={ + "model_mode": "chat", + "model_provider": "openai", + "model_name": "gpt-4o", + "prompts": [{"role": "user", "text": "hello"}], + "usage": { + "prompt_tokens": 50, + "prompt_unit_price": 0.00001, + "prompt_price_unit": 0.001, + "prompt_price": 0.0005, + "completion_tokens": 30, + "completion_unit_price": 0.00003, + "completion_price_unit": 0.001, + "completion_price": 0.0009, + "total_tokens": 80, + "total_price": 0.0014, + "currency": "USD", + "latency": 2.3, + }, + "finish_reason": "stop", + }, + outputs_dict={ + "text": "world", + "usage": { + "prompt_tokens": 50, + "completion_tokens": 30, + "total_tokens": 80, + "total_price": 0.0014, + "currency": "USD", + }, + "finish_reason": "stop", + }, + execution_metadata_dict={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 80, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.0014, + WorkflowNodeExecutionMetadataKey.CURRENCY: "USD", + }, + ) + + +class TestDraftTraceTokenDataFlow: + """End-to-end test: verify all token and model fields survive from + _build_node_execution_data through ops_trace_manager.draft_node_execution_trace + to the DraftNodeExecutionTrace that enterprise_trace.py consumes. + """ + + def test_all_token_and_model_fields_reach_trace_info(self) -> None: + """Simulate the full draft trace data flow for an LLM node and + assert every token/model field that enterprise_trace._emit_node_execution_trace + reads is populated correctly on the resulting DraftNodeExecutionTrace.""" + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_llm_execution() + node_data = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id="run-flow", + ) + + # Simulate what ops_trace_manager.draft_node_execution_trace does: + # it calls node_execution_trace(node_execution_data=node_data) which + # reads top-level keys from node_data. Verify all expected keys exist. 
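+        # The strict set equality below doubles as a contract check: any key
+        # added to or removed from _build_node_execution_data fails the
+        # assertion, forcing the per-field consumer notes in expected_keys
+        # to be kept in sync.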
+ expected_keys = { + # Token fields — read by enterprise_trace._emit_node_execution_trace + "total_tokens", + "total_price", + "currency", + "prompt_tokens", + "completion_tokens", + # Model fields — read for span attrs and metric labels + "model_provider", + "model_name", + # Node identity — read for span attrs + "node_type", + "node_execution_id", + "node_id", + "title", + "status", + "error", + "elapsed_time", + # Workflow context + "workflow_id", + "workflow_execution_id", + "tenant_id", + "app_id", + # Structure fields + "index", + "predecessor_node_id", + "iteration_id", + "iteration_index", + "loop_id", + "loop_index", + "parallel_id", + # Tool field + "tool_name", + # Content fields + "node_inputs", + "node_outputs", + "process_data", + # Timestamps + "created_at", + "finished_at", + } + assert set(node_data.keys()) == expected_keys + + # Verify token/model values are correct (not None/zero when data exists) + assert node_data["total_tokens"] == 80 + assert node_data["total_price"] == 0.0014 + assert node_data["currency"] == "USD" + assert node_data["prompt_tokens"] == 50 + assert node_data["completion_tokens"] == 30 + assert node_data["model_provider"] == "openai" + assert node_data["model_name"] == "gpt-4o" + assert node_data["node_type"] == "llm" + + def test_non_llm_node_has_none_for_model_and_token_breakdown(self) -> None: + """For non-LLM nodes (e.g. code, IF), model and token breakdown + should be None, but total_tokens from metadata should still work.""" + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution( + node_type="code", + process_data_dict={"code": "print('hi')"}, + outputs_dict={"result": "hi"}, + execution_metadata_dict={}, + ) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["model_provider"] is None + assert result["model_name"] is None + assert result["prompt_tokens"] is None + assert result["completion_tokens"] is None + assert result["total_tokens"] == 0 + + def test_none_process_data_and_none_outputs(self) -> None: + """Both process_data_dict and outputs_dict are None — exercises + the `or {}` fallback and isinstance guard together.""" + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_execution(process_data_dict=None, outputs_dict=None) + result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None) + + assert result["model_provider"] is None + assert result["model_name"] is None + assert result["prompt_tokens"] is None + assert result["completion_tokens"] is None + + def test_node_data_feeds_into_draft_node_execution_trace(self) -> None: + """Verify the node_data dict can be consumed by + ops_trace_manager.draft_node_execution_trace without error and + produces a DraftNodeExecutionTrace with correct token/model fields.""" + + from enterprise.telemetry.draft_trace import _build_node_execution_data + + execution = _make_llm_execution() + node_data = _build_node_execution_data( + execution=execution, + outputs=None, + workflow_execution_id="run-e2e", + ) + + # Directly construct DraftNodeExecutionTrace the way + # ops_trace_manager.node_execution_trace does (lines 1315-1350), + # skipping DB lookups by providing minimal metadata. 
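+        # The .get() defaults below mean a key missing from node_data degrades
+        # to an empty/None value rather than raising KeyError, matching the
+        # tolerant construction this test mimics.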
+ from core.ops.entities.trace_entity import DraftNodeExecutionTrace + + trace_info = DraftNodeExecutionTrace( + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata={}, + ) + + # These are the fields enterprise_trace._emit_node_execution_trace reads + assert trace_info.total_tokens == 80 + assert trace_info.prompt_tokens == 50 + assert trace_info.completion_tokens == 30 + assert trace_info.model_provider == "openai" + assert trace_info.model_name == "gpt-4o" + assert trace_info.node_type == "llm" + assert trace_info.total_price == 0.0014 + assert trace_info.currency == "USD" diff --git a/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py new file mode 100644 index 0000000000..bb1f78b80c --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py @@ -0,0 +1,1327 @@ +"""Unit tests for EnterpriseOtelTrace.""" + +from __future__ import annotations + +import json +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryEvent, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_exporter(): + exporter = MagicMock() + exporter.include_content = True + return exporter + + +@pytest.fixture +def trace_handler(mock_exporter): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=mock_exporter): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + handler = EnterpriseOtelTrace() + return handler + + +# --------------------------------------------------------------------------- +# Factory helpers +# 
--------------------------------------------------------------------------- + +_T0 = datetime(2024, 1, 10, 12, 0, 0, tzinfo=UTC) +_T1 = datetime(2024, 1, 10, 12, 0, 5, tzinfo=UTC) + + +def make_workflow_info(**overrides) -> WorkflowTraceInfo: + defaults: dict = { + "workflow_id": "wf-001", + "tenant_id": "tenant-abc", + "workflow_run_id": "run-001", + "workflow_run_elapsed_time": 5.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"query": "hello"}, + "workflow_run_outputs": {"answer": "world"}, + "workflow_run_version": "1", + "total_tokens": 100, + "prompt_tokens": 60, + "completion_tokens": 40, + "file_list": [], + "query": "hello", + "start_time": _T0, + "end_time": _T1, + "metadata": { + "app_id": "app-001", + "tenant_id": "tenant-abc", + "app_name": "MyApp", + "workspace_name": "WS", + "triggered_from": "api", + }, + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def make_node_info(**overrides) -> WorkflowNodeTraceInfo: + defaults: dict = { + "workflow_id": "wf-001", + "workflow_run_id": "run-001", + "tenant_id": "tenant-abc", + "node_execution_id": "ne-001", + "node_id": "node-001", + "node_type": "llm", + "title": "LLM Node", + "status": "succeeded", + "elapsed_time": 2.5, + "index": 1, + "total_tokens": 80, + "prompt_tokens": 50, + "completion_tokens": 30, + "model_provider": "openai", + "model_name": "gpt-4", + "start_time": _T0, + "end_time": _T1, + "metadata": { + "app_id": "app-001", + "tenant_id": "tenant-abc", + "app_name": "MyApp", + }, + } + defaults.update(overrides) + return WorkflowNodeTraceInfo(**defaults) + + +def make_draft_node_info(**overrides) -> DraftNodeExecutionTrace: + defaults: dict = { + "workflow_id": "wf-001", + "workflow_run_id": "run-draft-001", + "tenant_id": "tenant-abc", + "node_execution_id": "ne-draft-001", + "node_id": "node-001", + "node_type": "llm", + "title": "Draft LLM", + "status": "succeeded", + "elapsed_time": 1.2, + "index": 0, + "total_tokens": 50, + "start_time": _T0, + "end_time": _T1, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return DraftNodeExecutionTrace(**defaults) + + +def make_message_info(**overrides) -> MessageTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "conversation_model": "gpt-4", + "message_tokens": 40, + "answer_tokens": 60, + "total_tokens": 100, + "conversation_mode": "chat", + "start_time": _T0, + "end_time": _T1, + "inputs": "user input", + "outputs": "assistant output", + "metadata": { + "app_id": "app-001", + "tenant_id": "tenant-abc", + "from_source": "api", + "ls_provider": "openai", + "ls_model_name": "gpt-4", + "status": "succeeded", + }, + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def make_tool_info(**overrides) -> ToolTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "tool_name": "web_search", + "tool_inputs": {"query": "test"}, + "tool_outputs": "search results", + "tool_config": {"max_results": 5}, + "tool_parameters": {"verbose": True}, + "time_cost": 1.5, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def make_moderation_info(**overrides) -> ModerationTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "flagged": False, + "action": "pass", + "preset_response": "", + "query": "is this ok?", + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return ModerationTraceInfo(**defaults) + + +def 
make_suggested_question_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "total_tokens": 30, + "suggested_question": ["Question A?", "Question B?"], + "level": "info", + "status": "succeeded", + "model_provider": "openai", + "model_id": "gpt-3.5-turbo", + "start_time": _T0, + "end_time": _T1, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +def make_dataset_retrieval_info(**overrides) -> DatasetRetrievalTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "documents": [ + { + "metadata": { + "dataset_id": "ds-001", + "dataset_name": "MyDataset", + "document_id": "doc-001", + "segment_id": "seg-001", + "score": 0.95, + } + } + ], + "inputs": "search query", + "metadata": { + "app_id": "app-001", + "tenant_id": "tenant-abc", + "embedding_models": { + "ds-001": { + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-3-small", + } + }, + }, + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def make_generate_name_info(**overrides) -> GenerateNameTraceInfo: + defaults: dict = { + "message_id": "msg-001", + "tenant_id": "tenant-abc", + "conversation_id": "conv-001", + "inputs": "some content", + "outputs": "My Conversation", + "start_time": _T0, + "end_time": _T1, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return GenerateNameTraceInfo(**defaults) + + +def make_prompt_generation_info(**overrides) -> PromptGenerationTraceInfo: + defaults: dict = { + "tenant_id": "tenant-abc", + "user_id": "user-001", + "app_id": "app-001", + "operation_type": "rule_generate", + "instruction": "Generate a helpful prompt", + "prompt_tokens": 50, + "completion_tokens": 100, + "total_tokens": 150, + "model_provider": "openai", + "model_name": "gpt-4", + "latency": 3.2, + "metadata": {"app_id": "app-001", "tenant_id": "tenant-abc"}, + } + defaults.update(overrides) + return PromptGenerationTraceInfo(**defaults) + + +# --------------------------------------------------------------------------- +# Constructor +# --------------------------------------------------------------------------- + + +def test_init_raises_when_exporter_is_none(): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=None): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + with pytest.raises(RuntimeError, match="exporter is not initialized"): + EnterpriseOtelTrace() + + +def test_init_succeeds_with_valid_exporter(mock_exporter): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=mock_exporter): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + handler = EnterpriseOtelTrace() + assert handler._exporter is mock_exporter + + +# --------------------------------------------------------------------------- +# Helper methods +# --------------------------------------------------------------------------- + + +class TestSafePayloadValue: + def test_string_passthrough(self, trace_handler): + assert trace_handler._safe_payload_value("hello") == "hello" + + def test_dict_passthrough(self, trace_handler): + d = {"key": "val"} + assert trace_handler._safe_payload_value(d) == d + + def test_list_passthrough(self, trace_handler): + lst = [1, 2, 3] + assert trace_handler._safe_payload_value(lst) == lst + + def test_none_returns_none(self, trace_handler): + assert 
trace_handler._safe_payload_value(None) is None + + def test_int_returns_none(self, trace_handler): + assert trace_handler._safe_payload_value(42) is None + + def test_bool_returns_none(self, trace_handler): + assert trace_handler._safe_payload_value(True) is None + + +class TestMaybeJson: + def test_none_returns_none(self, trace_handler): + assert trace_handler._maybe_json(None) is None + + def test_string_passthrough(self, trace_handler): + assert trace_handler._maybe_json("hello") == "hello" + + def test_dict_serialised(self, trace_handler): + result = trace_handler._maybe_json({"a": 1}) + assert result == json.dumps({"a": 1}) + + def test_list_serialised(self, trace_handler): + result = trace_handler._maybe_json([1, 2]) + assert result == "[1, 2]" + + def test_non_serialisable_falls_back_to_str(self, trace_handler): + class Unserializable: + def __repr__(self): + return "Unserializable()" + + obj = Unserializable() + result = trace_handler._maybe_json(obj) + assert isinstance(result, str) + + +class TestContentOrRef: + def test_returns_content_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + result = trace_handler._content_or_ref("actual content", "ref:x=1") + assert result == "actual content" + + def test_returns_ref_when_include_content_false(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + result = trace_handler._content_or_ref("actual content", "ref:x=1") + assert result == "ref:x=1" + + def test_dict_serialised_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + result = trace_handler._content_or_ref({"key": "val"}, "ref:x=1") + assert result == json.dumps({"key": "val"}) + + def test_none_returns_none_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + result = trace_handler._content_or_ref(None, "ref:x=1") + assert result is None + + +# --------------------------------------------------------------------------- +# trace() dispatcher +# --------------------------------------------------------------------------- + + +class TestTraceDispatcher: + def test_dispatches_workflow_trace(self, trace_handler): + with patch.object(trace_handler, "_workflow_trace") as mock_method: + info = make_workflow_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_message_trace(self, trace_handler): + with patch.object(trace_handler, "_message_trace") as mock_method: + info = make_message_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_tool_trace(self, trace_handler): + with patch.object(trace_handler, "_tool_trace") as mock_method: + info = make_tool_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_draft_node_execution_trace(self, trace_handler): + with patch.object(trace_handler, "_draft_node_execution_trace") as mock_method: + info = make_draft_node_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_node_execution_trace(self, trace_handler): + with patch.object(trace_handler, "_node_execution_trace") as mock_method: + info = make_node_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_moderation_trace(self, trace_handler): + with patch.object(trace_handler, "_moderation_trace") as mock_method: + info = make_moderation_info() + trace_handler.trace(info) + 
mock_method.assert_called_once_with(info) + + def test_dispatches_suggested_question_trace(self, trace_handler): + with patch.object(trace_handler, "_suggested_question_trace") as mock_method: + info = make_suggested_question_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_dataset_retrieval_trace(self, trace_handler): + with patch.object(trace_handler, "_dataset_retrieval_trace") as mock_method: + info = make_dataset_retrieval_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_generate_name_trace(self, trace_handler): + with patch.object(trace_handler, "_generate_name_trace") as mock_method: + info = make_generate_name_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_dispatches_prompt_generation_trace(self, trace_handler): + with patch.object(trace_handler, "_prompt_generation_trace") as mock_method: + info = make_prompt_generation_info() + trace_handler.trace(info) + mock_method.assert_called_once_with(info) + + def test_draft_node_dispatched_before_node(self, trace_handler): + """DraftNodeExecutionTrace is a subclass of WorkflowNodeTraceInfo; + it must be dispatched to _draft_node_execution_trace, not _node_execution_trace.""" + with ( + patch.object(trace_handler, "_draft_node_execution_trace") as mock_draft, + patch.object(trace_handler, "_node_execution_trace") as mock_node, + ): + info = make_draft_node_info() + trace_handler.trace(info) + mock_draft.assert_called_once_with(info) + mock_node.assert_not_called() + + +# --------------------------------------------------------------------------- +# _workflow_trace +# --------------------------------------------------------------------------- + + +class TestWorkflowTrace: + def test_emits_correct_span_attributes(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + info = make_workflow_info() + trace_handler._workflow_trace(info) + + mock_exporter.export_span.assert_called_once() + span_call = mock_exporter.export_span.call_args + assert span_call[0][0] == EnterpriseTelemetrySpan.WORKFLOW_RUN + attrs = span_call[0][1] + assert attrs["dify.workflow.run_id"] == "run-001" + assert attrs["dify.workflow.id"] == "wf-001" + assert attrs["dify.tenant_id"] == "tenant-abc" + assert attrs["dify.workflow.status"] == "succeeded" + assert attrs["gen_ai.usage.total_tokens"] == 100 + + def test_span_timing_passed_correctly(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info() + trace_handler._workflow_trace(info) + + span_call = mock_exporter.export_span.call_args + assert span_call[1]["start_time"] == _T0 + assert span_call[1]["end_time"] == _T1 + + def test_emits_companion_log_with_event_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._workflow_trace(make_workflow_info()) + + mock_log.assert_called_once() + assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetryEvent.WORKFLOW_RUN + assert mock_log.call_args[1]["tenant_id"] == "tenant-abc" + + def test_companion_log_includes_content_when_enabled(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._workflow_trace(make_workflow_info()) + + log_attrs = 
mock_log.call_args[1]["attributes"] + assert log_attrs["dify.workflow.inputs"] == json.dumps({"query": "hello"}) + assert log_attrs["dify.workflow.outputs"] == json.dumps({"answer": "world"}) + + def test_companion_log_uses_ref_when_content_disabled(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._workflow_trace(make_workflow_info()) + + log_attrs = mock_log.call_args[1]["attributes"] + assert log_attrs["dify.workflow.inputs"].startswith("ref:workflow_run_id=") + assert log_attrs["dify.workflow.outputs"].startswith("ref:workflow_run_id=") + + def test_increments_token_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._workflow_trace(make_workflow_info()) + + token_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.TOKENS + ] + assert len(token_calls) == 1 + assert token_calls[0][0][1] == 100 + + def test_increments_input_and_output_token_counters(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._workflow_trace(make_workflow_info()) + + all_calls = mock_exporter.increment_counter.call_args_list + counter_names = [c[0][0] for c in all_calls] + assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names + assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names + + def test_no_input_token_counter_when_prompt_tokens_zero(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info(prompt_tokens=0) + trace_handler._workflow_trace(info) + + all_calls = mock_exporter.increment_counter.call_args_list + counter_names = [c[0][0] for c in all_calls] + assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names + + def test_records_workflow_duration_histogram(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._workflow_trace(make_workflow_info()) + + mock_exporter.record_histogram.assert_called_once() + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][0] == EnterpriseTelemetryHistogram.WORKFLOW_DURATION + assert hist_call[0][1] == pytest.approx(5.0) + + def test_duration_falls_back_to_elapsed_time_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=7.3) + trace_handler._workflow_trace(info) + + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][1] == pytest.approx(7.3) + + def test_duration_defaults_to_zero_when_no_timing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=0) + trace_handler._workflow_trace(info) + + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][1] == pytest.approx(0.0) + + def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info(error="Something went wrong", workflow_run_status="failed") + trace_handler._workflow_trace(info) + + error_calls = [ + c for c 
in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_no_error_counter_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._workflow_trace(make_workflow_info()) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 0 + + def test_parent_trace_context_injected_into_span_attrs(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_workflow_info( + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "parent_trace_context": { + "trace_id": "outer-trace", + "parent_node_execution_id": "outer-ne-001", + "parent_workflow_run_id": "outer-run-001", + "parent_app_id": "outer-app-001", + }, + } + ) + trace_handler._workflow_trace(info) + + attrs = mock_exporter.export_span.call_args[0][1] + assert attrs["dify.parent.trace_id"] == "outer-trace" + assert attrs["dify.parent.node.execution_id"] == "outer-ne-001" + assert attrs["dify.parent.workflow.run_id"] == "outer-run-001" + assert attrs["dify.parent.app.id"] == "outer-app-001" + + +# --------------------------------------------------------------------------- +# _node_execution_trace / _emit_node_execution_trace +# --------------------------------------------------------------------------- + + +class TestNodeExecutionTrace: + def test_emits_span_with_node_execution_span_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info()) + + span_call = mock_exporter.export_span.call_args + assert span_call[0][0] == EnterpriseTelemetrySpan.NODE_EXECUTION + + def test_span_contains_core_node_attributes(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info()) + + attrs = mock_exporter.export_span.call_args[0][1] + assert attrs["dify.node.execution_id"] == "ne-001" + assert attrs["dify.node.id"] == "node-001" + assert attrs["dify.node.type"] == "llm" + assert attrs["dify.node.status"] == "succeeded" + assert attrs["gen_ai.request.model"] == "gpt-4" + assert attrs["gen_ai.provider.name"] == "openai" + + def test_increments_token_counters_when_tokens_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info()) + + counter_names = [c[0][0] for c in mock_exporter.increment_counter.call_args_list] + assert EnterpriseTelemetryCounter.TOKENS in counter_names + assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names + assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names + + def test_no_token_counters_when_total_tokens_zero(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info(total_tokens=0)) + + counter_names = [c[0][0] for c in mock_exporter.increment_counter.call_args_list] + assert EnterpriseTelemetryCounter.TOKENS not in counter_names + assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names + + def test_records_node_duration_histogram(self, trace_handler, mock_exporter): + with 
patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info()) + + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][0] == EnterpriseTelemetryHistogram.NODE_DURATION + assert hist_call[0][1] == pytest.approx(2.5) + + def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._node_execution_trace(make_node_info(error="Node failed", status="failed")) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_emits_companion_log_with_span_name_as_event(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._node_execution_trace(make_node_info()) + + mock_log.assert_called_once() + assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetrySpan.NODE_EXECUTION.value + + def test_plugin_name_added_to_duration_labels_for_tool_node(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_node_info( + node_type="tool", + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "plugin_name": "my-plugin", + }, + ) + trace_handler._node_execution_trace(info) + + hist_call = mock_exporter.record_histogram.call_args + duration_labels = hist_call[0][2] + assert duration_labels.get("plugin_name") == "my-plugin" + + def test_plugin_name_not_added_for_non_tool_node(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_node_info( + node_type="llm", + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "plugin_name": "my-plugin", + }, + ) + trace_handler._node_execution_trace(info) + + hist_call = mock_exporter.record_histogram.call_args + duration_labels = hist_call[0][2] + assert "plugin_name" not in duration_labels + + def test_companion_log_inputs_use_ref_when_content_disabled(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._node_execution_trace( + make_node_info(node_inputs={"prompt": "hello"}, node_outputs={"text": "world"}) + ) + + log_attrs = mock_log.call_args[1]["attributes"] + assert log_attrs["dify.node.inputs"].startswith("ref:node_execution_id=") + assert log_attrs["dify.node.outputs"].startswith("ref:node_execution_id=") + + +# --------------------------------------------------------------------------- +# _draft_node_execution_trace +# --------------------------------------------------------------------------- + + +class TestDraftNodeExecutionTrace: + def test_uses_draft_span_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + trace_handler._draft_node_execution_trace(make_draft_node_info()) + + span_call = mock_exporter.export_span.call_args + assert span_call[0][0] == EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION + + def test_correlation_id_is_node_execution_id(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_draft_node_info() + trace_handler._draft_node_execution_trace(info) + + span_call = mock_exporter.export_span.call_args + assert span_call[1]["correlation_id"] == 
"ne-draft-001" + + def test_trace_correlation_override_is_workflow_run_id(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): + info = make_draft_node_info() + trace_handler._draft_node_execution_trace(info) + + span_call = mock_exporter.export_span.call_args + assert span_call[1]["trace_correlation_override"] == "run-draft-001" + + def test_companion_log_uses_draft_span_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: + trace_handler._draft_node_execution_trace(make_draft_node_info()) + + assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION.value + + +# --------------------------------------------------------------------------- +# _message_trace +# --------------------------------------------------------------------------- + + +class TestMessageTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info()) + + mock_emit.assert_called_once() + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MESSAGE_RUN + + def test_emits_correct_tenant_and_user(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info()) + + assert mock_emit.call_args[1]["tenant_id"] == "tenant-abc" + + def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.message.duration"] == pytest.approx(5.0) + + def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info(start_time=None, end_time=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.message.duration" not in attrs + + def test_records_duration_histogram_when_timestamps_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info()) + + hist_calls = [ + c + for c in mock_exporter.record_histogram.call_args_list + if c[0][0] == EnterpriseTelemetryHistogram.MESSAGE_DURATION + ] + assert len(hist_calls) == 1 + assert hist_calls[0][0][1] == pytest.approx(5.0) + + def test_no_duration_histogram_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info(start_time=None, end_time=None)) + + hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list] + assert EnterpriseTelemetryHistogram.MESSAGE_DURATION not in hist_names + + def test_records_ttft_histogram_when_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=0.42)) + + ttft_calls = [ + c + for c in mock_exporter.record_histogram.call_args_list + if c[0][0] == EnterpriseTelemetryHistogram.MESSAGE_TTFT + ] + assert len(ttft_calls) == 1 
+ assert ttft_calls[0][0][1] == pytest.approx(0.42) + + def test_no_ttft_histogram_when_not_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=None)) + + hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list] + assert EnterpriseTelemetryHistogram.MESSAGE_TTFT not in hist_names + + def test_increments_token_counters(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info()) + + counter_names = [c[0][0] for c in mock_exporter.increment_counter.call_args_list] + assert EnterpriseTelemetryCounter.TOKENS in counter_names + assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names + assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names + + def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._message_trace(make_message_info(error="LLM failed")) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._message_trace(make_message_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.message.inputs"].startswith("ref:message_id=") + assert attrs["dify.message.outputs"].startswith("ref:message_id=") + + +# --------------------------------------------------------------------------- +# _tool_trace +# --------------------------------------------------------------------------- + + +class TestToolTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.TOOL_EXECUTION + + def test_status_is_succeeded_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.tool.status"] == "succeeded" + + def test_status_is_failed_on_error(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info(error="Tool error")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.tool.status"] == "failed" + + def test_records_tool_duration_histogram(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._tool_trace(make_tool_info()) + + hist_call = mock_exporter.record_histogram.call_args + assert hist_call[0][0] == EnterpriseTelemetryHistogram.TOOL_DURATION + assert hist_call[0][1] == pytest.approx(1.5) + + def test_error_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._tool_trace(make_tool_info(error="Tool 
crashed")) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.tool.inputs"].startswith("ref:message_id=") + assert attrs["dify.tool.outputs"].startswith("ref:message_id=") + + def test_inputs_present_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._tool_trace(make_tool_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.tool.inputs"] == json.dumps({"query": "test"}) + assert attrs["dify.tool.outputs"] == "search results" + + def test_increments_requests_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._tool_trace(make_tool_info()) + + request_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.REQUESTS + ] + assert len(request_calls) == 1 + assert request_calls[0][0][2]["type"] == "tool" + + +# --------------------------------------------------------------------------- +# _moderation_trace +# --------------------------------------------------------------------------- + + +class TestModerationTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MODERATION_CHECK + + def test_flagged_true_sets_attribute(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info(flagged=True)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.moderation.flagged"] is True + + def test_flagged_false_sets_attribute(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info(flagged=False)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.moderation.flagged"] is False + + def test_query_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.moderation.query"].startswith("ref:message_id=") + + def test_query_present_when_include_content_true(self, trace_handler, mock_exporter): + mock_exporter.include_content = True + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._moderation_trace(make_moderation_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.moderation.query"] == "is this ok?" 
+ + def test_increments_requests_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._moderation_trace(make_moderation_info()) + + request_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.REQUESTS + ] + assert len(request_calls) == 1 + assert request_calls[0][0][2]["type"] == "moderation" + + +# --------------------------------------------------------------------------- +# _suggested_question_trace +# --------------------------------------------------------------------------- + + +class TestSuggestedQuestionTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION + + def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.duration"] == pytest.approx(5.0) + + def test_duration_is_none_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info(start_time=None, end_time=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.duration"] is None + + def test_status_is_failed_when_error_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info(error="Generation failed")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.status"] == "failed" + + def test_status_falls_back_to_succeeded_when_no_error(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info(status=None, error=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.status"] == "succeeded" + + def test_question_count_attribute(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.count"] == 2 + + def test_questions_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._suggested_question_trace(make_suggested_question_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.suggested_question.questions"].startswith("ref:message_id=") + + def test_increments_requests_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + 
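# Exercise the handler; the REQUESTS counter must carry type="suggested_question".
+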
trace_handler._suggested_question_trace(make_suggested_question_info()) + + request_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.REQUESTS + ] + assert len(request_calls) == 1 + assert request_calls[0][0][2]["type"] == "suggested_question" + + +# --------------------------------------------------------------------------- +# _dataset_retrieval_trace +# --------------------------------------------------------------------------- + + +class TestDatasetRetrievalTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.DATASET_RETRIEVAL + + def test_document_count_attribute(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.document_count"] == 1 + + def test_dataset_ids_extracted(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert "ds-001" in attrs["dify.dataset.id"] + + def test_empty_documents_has_zero_count(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[])) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.document_count"] == 0 + + def test_status_succeeded_when_no_error(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.status"] == "succeeded" + + def test_status_failed_when_error_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(error="DB error")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.status"] == "failed" + + def test_embedding_model_attributes_set_when_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.dataset.embedding_providers" in attrs + assert "dify.dataset.embedding_models" in attrs + + def test_no_embedding_model_attributes_when_not_provided(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace( + make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"}) + ) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.dataset.embedding_providers" not in attrs + assert "dify.dataset.embedding_models" not in attrs + + def test_rerank_attributes_set_when_present(self, 
trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace( + make_dataset_retrieval_info( + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "rerank_model_provider": "cohere", + "rerank_model_name": "rerank-english", + } + ) + ) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.rerank_provider"] == "cohere" + assert attrs["dify.retrieval.rerank_model"] == "rerank-english" + + def test_no_rerank_attributes_when_not_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace( + make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"}) + ) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.retrieval.rerank_provider" not in attrs + assert "dify.retrieval.rerank_model" not in attrs + + def test_dataset_retrieval_counter_incremented_per_dataset(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + ds_calls = [ + c + for c in mock_exporter.increment_counter.call_args_list + if c[0][0] == EnterpriseTelemetryCounter.DATASET_RETRIEVALS + ] + assert len(ds_calls) == 1 + assert ds_calls[0][0][2]["dataset_id"] == "ds-001" + + def test_no_dataset_retrieval_counter_when_no_documents(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[])) + + ds_calls = [ + c + for c in mock_exporter.increment_counter.call_args_list + if c[0][0] == EnterpriseTelemetryCounter.DATASET_RETRIEVALS + ] + assert len(ds_calls) == 0 + + def test_query_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.retrieval.query"].startswith("ref:message_id=") + + +# --------------------------------------------------------------------------- +# _generate_name_trace +# --------------------------------------------------------------------------- + + +class TestGenerateNameTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION + + def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.duration"] == pytest.approx(5.0) + + def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info(start_time=None, end_time=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert 
attrs["dify.generate_name.duration"] is None + + def test_status_succeeded_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.status"] == "succeeded" + + def test_status_failed_when_metadata_has_error(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace( + make_generate_name_info( + metadata={ + "app_id": "app-001", + "tenant_id": "tenant-abc", + "error": "Name generation failed", + } + ) + ) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.status"] == "failed" + + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._generate_name_trace(make_generate_name_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.generate_name.inputs"].startswith("ref:conversation_id=") + assert attrs["dify.generate_name.outputs"].startswith("ref:conversation_id=") + + def test_increments_requests_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._generate_name_trace(make_generate_name_info()) + + request_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.REQUESTS + ] + assert len(request_calls) == 1 + assert request_calls[0][0][2]["type"] == "generate_name" + + +# --------------------------------------------------------------------------- +# _prompt_generation_trace +# --------------------------------------------------------------------------- + + +class TestPromptGenerationTrace: + def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION + + def test_status_succeeded_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.prompt_generation.status"] == "succeeded" + + def test_status_failed_when_error_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Generation error")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.prompt_generation.status"] == "failed" + + def test_token_counters_incremented(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + counter_names = [c[0][0] for c in mock_exporter.increment_counter.call_args_list] + assert EnterpriseTelemetryCounter.TOKENS in counter_names + assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names + assert 
EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names + + def test_records_duration_histogram(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + hist_calls = [ + c + for c in mock_exporter.record_histogram.call_args_list + if c[0][0] == EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION + ] + assert len(hist_calls) == 1 + assert hist_calls[0][0][1] == pytest.approx(3.2) + + def test_total_price_attribute_set_when_present(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=0.05, currency="USD")) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.prompt_generation.total_price"] == pytest.approx(0.05) + assert attrs["dify.prompt_generation.currency"] == "USD" + + def test_no_total_price_attribute_when_none(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=None)) + + attrs = mock_emit.call_args[1]["attributes"] + assert "dify.prompt_generation.total_price" not in attrs + + def test_error_increments_error_counter(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Prompt failed")) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 1 + + def test_no_error_counter_on_success(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + error_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.ERRORS + ] + assert len(error_calls) == 0 + + def test_instruction_gated_by_include_content(self, trace_handler, mock_exporter): + mock_exporter.include_content = False + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + attrs = mock_emit.call_args[1]["attributes"] + assert attrs["dify.prompt_generation.instruction"].startswith("ref:trace_id=") + + def test_operation_type_label_used_in_token_counters(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): + trace_handler._prompt_generation_trace(make_prompt_generation_info(operation_type="code_generate")) + + token_calls = [ + c for c in mock_exporter.increment_counter.call_args_list if c[0][0] == EnterpriseTelemetryCounter.TOKENS + ] + assert len(token_calls) == 1 + assert token_calls[0][0][2]["operation_type"] == "code_generate" + + def test_emits_correct_tenant_id(self, trace_handler, mock_exporter): + with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: + trace_handler._prompt_generation_trace(make_prompt_generation_info()) + + assert mock_emit.call_args[1]["tenant_id"] == "tenant-abc" diff --git a/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py b/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py new file mode 
100644 index 0000000000..b70c0260d5 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py @@ -0,0 +1,54 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry import event_handlers +from enterprise.telemetry.contracts import TelemetryCase + + +@pytest.fixture +def mock_gateway_emit(): + with patch("core.telemetry.gateway.emit") as mock: + yield mock + + +def test_handle_app_created_calls_task(mock_gateway_emit): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + sender.mode = "chat" + + event_handlers._handle_app_created(sender) + + mock_gateway_emit.assert_called_once_with( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": "tenant-456"}, + payload={"app_id": "app-123", "mode": "chat"}, + ) + + +def test_handle_app_created_no_exporter(mock_gateway_emit): + """Gateway handles exporter availability internally; handler always calls gateway.""" + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + + event_handlers._handle_app_created(sender) + + mock_gateway_emit.assert_called_once() + + +def test_handlers_create_valid_envelopes(mock_gateway_emit): + """Verify handlers pass correct TelemetryCase and payload structure.""" + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + sender.mode = "chat" + + event_handlers._handle_app_created(sender) + + call_kwargs = mock_gateway_emit.call_args[1] + assert call_kwargs["case"] == TelemetryCase.APP_CREATED + assert call_kwargs["context"]["tenant_id"] == "tenant-456" + assert call_kwargs["payload"]["app_id"] == "app-123" + assert call_kwargs["payload"]["mode"] == "chat" diff --git a/api/tests/unit_tests/enterprise/telemetry/test_exporter.py b/api/tests/unit_tests/enterprise/telemetry/test_exporter.py new file mode 100644 index 0000000000..6bdae13923 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_exporter.py @@ -0,0 +1,628 @@ +"""Unit tests for EnterpriseExporter and _ExporterFactory.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from configs.enterprise import EnterpriseTelemetryConfig +from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram +from enterprise.telemetry.exporter import EnterpriseExporter, _datetime_to_ns, _parse_otlp_headers + + +def test_config_api_key_default_empty(): + """Test that ENTERPRISE_OTLP_API_KEY defaults to empty string.""" + config = EnterpriseTelemetryConfig() + assert config.ENTERPRISE_OTLP_API_KEY == "" + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_only_injects_bearer_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key alone injects Bearer authorization header.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-secret-key", + ) + + EnterpriseExporter(mock_config) + + # Verify span exporter was called with Bearer header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer 
test-secret-key") in headers + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_empty_api_key_no_auth_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that empty API key does not inject authorization header.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify span exporter was called without authorization header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + # Headers should be None or not contain authorization + if headers is not None: + assert not any(key == "authorization" for key, _ in headers) + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_and_custom_headers_merge(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key and custom headers are merged correctly.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="x-custom=foo", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify both headers are present + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer test-key") in headers + assert ("x-custom", "foo") in headers + + +@patch("enterprise.telemetry.exporter.logger") +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_overrides_conflicting_header( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock, mock_logger: MagicMock +) -> None: + """Test that API key overrides conflicting authorization header and logs warning.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="authorization=Basic+old", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify Bearer header takes precedence + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer test-key") in headers + # Verify old authorization header is not present + assert ("authorization", "Basic old") not in headers + + # Verify warning was logged + mock_logger.warning.assert_called_once() + assert mock_logger.warning.call_args is not None + warning_message = mock_logger.warning.call_args[0][0] + assert "ENTERPRISE_OTLP_API_KEY is set" in warning_message + assert "authorization" in warning_message + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_https_endpoint_uses_secure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: 
MagicMock) -> None: + """Test that https:// endpoint enables TLS (insecure=False) for gRPC.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=False for both exporters (https:// scheme) + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is False + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is False + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_http_endpoint_uses_insecure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that http:// endpoint uses insecure gRPC (insecure=True).""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for both exporters (http:// scheme) + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True + + +@patch("enterprise.telemetry.exporter.HTTPSpanExporter") +@patch("enterprise.telemetry.exporter.HTTPMetricExporter") +def test_insecure_not_passed_to_http_exporters(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that insecure parameter is not passed to HTTP exporters.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="http", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure kwarg is NOT in HTTP exporter calls + assert mock_span_exporter.call_args is not None + assert "insecure" not in mock_span_exporter.call_args.kwargs + + assert mock_metric_exporter.call_args is not None + assert "insecure" not in mock_metric_exporter.call_args.kwargs + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_with_special_chars_preserved(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key with special characters is preserved without mangling.""" + special_key = "abc+def/ghi=jkl==" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY=special_key, + ) + + EnterpriseExporter(mock_config) + + # Verify special characters are preserved in Bearer header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert 
("authorization", f"Bearer {special_key}") in headers + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_no_scheme_localhost_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that endpoint without scheme defaults to insecure for localhost.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="localhost:4317", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for localhost without scheme + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_no_scheme_production_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that endpoint without scheme defaults to insecure (not https://).""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="collector.example.com:4317", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for any endpoint without https:// scheme + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True + + +# --------------------------------------------------------------------------- +# _parse_otlp_headers (line 55 — pair without "=" is skipped) +# --------------------------------------------------------------------------- + + +def test_parse_otlp_headers_empty_returns_empty_dict() -> None: + assert _parse_otlp_headers("") == {} + + +def test_parse_otlp_headers_value_may_contain_equals() -> None: + result = _parse_otlp_headers("token=abc=def==") + assert result == {"token": "abc=def=="} + + +def test_parse_otlp_headers_url_encoded() -> None: + result = _parse_otlp_headers("key=%E4%BD%A0%E5%A5%BD") + + assert result == {"key": "你好"} + + +# --------------------------------------------------------------------------- +# _datetime_to_ns (lines 64-68) +# --------------------------------------------------------------------------- + + +def test_datetime_to_ns_naive_treated_as_utc() -> None: + """Naive datetime must be interpreted as UTC (line 64-65).""" + naive = datetime(2024, 1, 1, 0, 0, 0) # no tzinfo + aware_utc = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + assert _datetime_to_ns(naive) == _datetime_to_ns(aware_utc) + + +def test_datetime_to_ns_tz_aware_converted_to_utc() -> None: + """Timezone-aware datetime must be converted to UTC before computing ns (line 66-67).""" + import zoneinfo + + eastern = zoneinfo.ZoneInfo("America/New_York") + dt_east = datetime(2024, 6, 1, 12, 0, 0, tzinfo=eastern) # UTC-4 in summer + dt_utc = dt_east.astimezone(UTC) + assert _datetime_to_ns(dt_east) == _datetime_to_ns(dt_utc) + + +def test_datetime_to_ns_returns_integer_nanoseconds() -> None: + dt = 
datetime(2024, 1, 1, 0, 0, 1, tzinfo=UTC)
+    result = _datetime_to_ns(dt)
+    # 2024-01-01 00:00:01 UTC is 1_704_067_201 seconds after the Unix epoch,
+    # so the result should be exactly 1_704_067_201 * 10**9 nanoseconds.
+    assert isinstance(result, int)
+    # 1_700_000_000_000_000_000 ns falls in November 2023, a safe lower bound
+    # for any 2024 timestamp.
+    assert result > 1_700_000_000_000_000_000
+
+
+# ---------------------------------------------------------------------------
+# EnterpriseExporter constructor — include_content property (line 115 / 288-289)
+# ---------------------------------------------------------------------------
+
+
+def _make_grpc_config(**overrides) -> SimpleNamespace:
+    defaults = {
+        "ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
+        "ENTERPRISE_OTLP_HEADERS": "",
+        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
+        "ENTERPRISE_SERVICE_NAME": "dify",
+        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
+        "ENTERPRISE_INCLUDE_CONTENT": True,
+        "ENTERPRISE_OTLP_API_KEY": "",
+    }
+    defaults.update(overrides)
+    return SimpleNamespace(**defaults)
+
+
+@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
+@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
+def test_include_content_true_stored_on_exporter(
+    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
+) -> None:
+    """include_content=True is stored as a public attribute (line 115)."""
+    exporter = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=True))
+    assert exporter.include_content is True
+
+
+@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
+@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
+def test_include_content_false_stored_on_exporter(
+    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
+) -> None:
+    """include_content=False is preserved (lines 288-289 path exercised by callers)."""
+    exporter = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=False))
+    assert exporter.include_content is False
+
+
+# ---------------------------------------------------------------------------
+# EnterpriseExporter constructor — gRPC setup (lines 64-68 exporter-init path)
+# ---------------------------------------------------------------------------
+
+
+@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
+@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
+def test_grpc_exporter_created_with_correct_endpoint(
+    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
+) -> None:
+    """GRPCSpanExporter and GRPCMetricExporter receive the configured endpoint."""
+    EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT="https://my-collector:4317"))
+
+    assert mock_span_exporter.call_args.kwargs["endpoint"] == "https://my-collector:4317"
+    assert mock_metric_exporter.call_args.kwargs["endpoint"] == "https://my-collector:4317"
+
+
+@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
+@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
+def test_grpc_exporter_empty_endpoint_passes_none(
+    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
+) -> None:
+    """Empty string endpoint is normalised to None for both gRPC exporters."""
+    EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT=""))
+
+    assert mock_span_exporter.call_args.kwargs["endpoint"] is None
+    assert mock_metric_exporter.call_args.kwargs["endpoint"] is None
+
+
+# ---------------------------------------------------------------------------
+# EnterpriseExporter.export_span (lines 204-271)
+# ---------------------------------------------------------------------------
+
+
+def _make_exporter_with_mock_tracer()
-> tuple[EnterpriseExporter, MagicMock, MagicMock]: + """Return (exporter, mock_tracer, mock_span) with OTEL internals fully mocked.""" + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + + mock_tracer = MagicMock() + mock_tracer.start_as_current_span.return_value = mock_span + + with ( + patch("enterprise.telemetry.exporter.GRPCSpanExporter"), + patch("enterprise.telemetry.exporter.GRPCMetricExporter"), + ): + exporter = EnterpriseExporter(_make_grpc_config()) + + exporter._tracer = mock_tracer + return exporter, mock_tracer, mock_span + + +@patch("enterprise.telemetry.exporter.set_correlation_id") +@patch("enterprise.telemetry.exporter.set_span_id_source") +def test_export_span_sets_and_clears_context(mock_set_span: MagicMock, mock_set_corr: MagicMock) -> None: + """export_span sets correlation/span context before the span and clears them in finally.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + exporter.export_span( + name="test.span", + attributes={"k": "v"}, + correlation_id="corr-1", + span_id_source="span-src-1", + ) + + # Context was set at the start of the call + mock_set_corr.assert_any_call("corr-1") + mock_set_span.assert_any_call("span-src-1") + # Context was cleared in finally + mock_set_corr.assert_called_with(None) + mock_set_span.assert_called_with(None) + + +def test_export_span_sets_attributes_on_span() -> None: + """All non-None attribute values are set on the span via set_attribute.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + exporter.export_span( + name="test.span", + attributes={"key1": "value1", "key2": None, "key3": 42}, + ) + + # set_attribute should be called for non-None values only + calls = list(mock_span.set_attribute.call_args_list) + keys_set = {c[0][0] for c in calls} + assert "key1" in keys_set + assert "key3" in keys_set + assert "key2" not in keys_set + + +def test_export_span_no_end_time_uses_end_on_exit() -> None: + """When end_time is None, end_on_exit=True is passed to start_as_current_span.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + exporter.export_span(name="test.span", attributes={}) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["end_on_exit"] is True + + +def test_export_span_with_end_time_calls_span_end() -> None: + """When end_time is provided, span.end() is called with the converted ns timestamp.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + start = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + end = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC) + + exporter.export_span(name="test.span", attributes={}, start_time=start, end_time=end) + + mock_span.end.assert_called_once() + end_ns = mock_span.end.call_args.kwargs["end_time"] + assert end_ns == _datetime_to_ns(end) + + +def test_export_span_with_start_time_passed_to_start_as_current_span() -> None: + """When start_time is provided it is converted to ns and passed to start_as_current_span.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + start = datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC) + exporter.export_span(name="test.span", attributes={}, start_time=start) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["start_time"] == _datetime_to_ns(start) + + +def test_export_span_root_span_no_parent_context() -> None: + """When span_id_source == correlation_id the span is root — no parent context.""" + 
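# A sketch of the presumed branching inside export_span (an assumption for
+    # readability, not the shipped code): trace identity derives from
+    # correlation_id, and a parent SpanContext is attached only when the
+    # span id source differs from it:
+    #
+    #     if span_id_source and span_id_source != correlation_id:
+    #         parent_ctx = context_from_correlation(correlation_id)  # child
+    #     else:
+    #         parent_ctx = None  # root span starts its own trace
+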
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + uid = "123e4567-e89b-12d3-a456-426614174000" + exporter.export_span( + name="root.span", + attributes={}, + correlation_id=uid, + span_id_source=uid, + ) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["context"] is None + + +def test_export_span_child_span_has_parent_context() -> None: + """When correlation_id != span_id_source the child span gets a parent context.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + corr_uid = "123e4567-e89b-12d3-a456-426614174000" + node_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3" + + exporter.export_span( + name="child.span", + attributes={}, + correlation_id=corr_uid, + span_id_source=node_uid, + ) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["context"] is not None + + +def test_export_span_cross_workflow_parent_context() -> None: + """When parent_span_id_source is set, the cross-workflow parent context is built.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + corr_uid = "123e4567-e89b-12d3-a456-426614174000" + parent_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3" + + exporter.export_span( + name="cross.span", + attributes={}, + correlation_id=corr_uid, + parent_span_id_source=parent_uid, + ) + + _, kwargs = mock_tracer.start_as_current_span.call_args + assert kwargs["context"] is not None + + +@patch("enterprise.telemetry.exporter.logger") +def test_export_span_logs_exception_on_error(mock_logger: MagicMock) -> None: + """If the span block raises, the exception is logged and context is still cleared.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + mock_tracer.start_as_current_span.side_effect = RuntimeError("boom") + + exporter.export_span(name="bad.span", attributes={}) # must not raise + + mock_logger.exception.assert_called_once() + assert "bad.span" in mock_logger.exception.call_args[0][1] + + +@patch("enterprise.telemetry.exporter.logger") +def test_export_span_invalid_trace_correlation_logs_warning(mock_logger: MagicMock) -> None: + """Invalid UUID for trace_correlation_override triggers a warning log.""" + exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() + + parent_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3" + exporter.export_span( + name="link.span", + attributes={}, + correlation_id="not-a-valid-uuid", + parent_span_id_source=parent_uid, + ) + + mock_logger.warning.assert_called() + + +# --------------------------------------------------------------------------- +# EnterpriseExporter.increment_counter (lines 276-278) +# --------------------------------------------------------------------------- + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_increment_counter_calls_add_on_counter(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """increment_counter calls .add() on the matching counter instrument.""" + exporter = EnterpriseExporter(_make_grpc_config()) + + mock_counter = MagicMock() + exporter._counters[EnterpriseTelemetryCounter.TOKENS] = mock_counter + + labels = {"tenant_id": "t1", "app_id": "app-1"} + exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 50, labels) + + mock_counter.add.assert_called_once_with(50, labels) + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def 
test_increment_counter_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """increment_counter silently does nothing when the counter is not found.""" + exporter = EnterpriseExporter(_make_grpc_config()) + exporter._counters.clear() + + # Should not raise + exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 5, {}) + + +# --------------------------------------------------------------------------- +# EnterpriseExporter.record_histogram (lines 283-285) +# --------------------------------------------------------------------------- + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_record_histogram_calls_record_on_histogram( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock +) -> None: + """record_histogram calls .record() on the matching histogram instrument.""" + exporter = EnterpriseExporter(_make_grpc_config()) + + mock_histogram = MagicMock() + exporter._histograms[EnterpriseTelemetryHistogram.WORKFLOW_DURATION] = mock_histogram + + labels = {"tenant_id": "t1"} + exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 3.14, labels) + + mock_histogram.record.assert_called_once_with(3.14, labels) + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_record_histogram_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """record_histogram silently does nothing when the histogram is not found.""" + exporter = EnterpriseExporter(_make_grpc_config()) + exporter._histograms.clear() + + # Should not raise + exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 1.0, {}) diff --git a/api/tests/unit_tests/enterprise/telemetry/test_gateway.py b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py new file mode 100644 index 0000000000..7e6ae64693 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.gateway import ( + CASE_ROUTING, + CASE_TO_TRACE_TASK, + PAYLOAD_SIZE_THRESHOLD_BYTES, + emit, +) +from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope + + +class TestCaseRoutingTable: + def test_all_cases_have_routing(self) -> None: + for case in TelemetryCase: + assert case in CASE_ROUTING, f"Missing routing for {case}" + + def test_trace_cases(self) -> None: + trace_cases = [ + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in trace_cases: + assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace" + + def test_metric_log_cases(self) -> None: + metric_log_cases = [ + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + ] + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log" + + def test_ce_eligible_cases(self) -> None: + ce_eligible_cases = [ + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + 
TelemetryCase.GENERATE_NAME, + ] + for case in ce_eligible_cases: + assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible" + + def test_enterprise_only_cases(self) -> None: + enterprise_only_cases = [ + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in enterprise_only_cases: + assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only" + + def test_trace_cases_have_task_name_mapping(self) -> None: + trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type is SignalType.TRACE] + for case in trace_cases: + assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}" + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_trace_case_routes_to_trace_manager( + self, + mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False) + def test_ce_eligible_trace_enqueued_when_ee_disabled( + self, + mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False) + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_enterprise_only_trace_enqueued_when_ee_enabled( + self, + mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + 
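# With EE telemetry enabled, the enterprise-only case must reach the trace queue.
+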
mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayMetricLogRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_metric_case_routes_to_celery_task( + self, + mock_delay: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_envelope_has_unique_event_id( + self, + mock_delay: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc"} + + emit(TelemetryCase.APP_CREATED, context, payload) + emit(TelemetryCase.APP_CREATED, context, payload) + + assert mock_delay.call_count == 2 + envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0]) + envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0]) + assert envelope1.event_id != envelope2.event_id + + +class TestGatewayPayloadSizing: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_small_payload_inlined( + self, + mock_delay: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"key": "small_value"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("core.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_stored( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_storage.save.assert_called_once() + storage_key = mock_storage.save.call_args[0][0] + assert storage_key.startswith("telemetry/tenant-123/") + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == {} + assert envelope.metadata is not None + assert envelope.metadata["payload_ref"] == storage_key + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("core.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_fallback_on_storage_error( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + mock_ee_enabled: MagicMock, + ) -> None: + mock_storage.save.side_effect = Exception("Storage failure") + context = {"tenant_id": 
"tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + +class TestTraceTaskNameMapping: + def test_workflow_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN] is TraceTaskName.WORKFLOW_TRACE + + def test_message_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN] is TraceTaskName.MESSAGE_TRACE + + def test_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION] is TraceTaskName.NODE_EXECUTION_TRACE + + def test_draft_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION] is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_prompt_generation_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION] is TraceTaskName.PROMPT_GENERATION_TRACE diff --git a/api/tests/unit_tests/enterprise/telemetry/test_id_generator.py b/api/tests/unit_tests/enterprise/telemetry/test_id_generator.py new file mode 100644 index 0000000000..dc2be14ebf --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_id_generator.py @@ -0,0 +1,201 @@ +"""Unit tests for enterprise/telemetry/id_generator.py.""" + +from __future__ import annotations + +import uuid +from unittest.mock import patch + +# --------------------------------------------------------------------------- +# compute_deterministic_span_id +# --------------------------------------------------------------------------- + + +class TestComputeDeterministicSpanId: + def test_returns_lower_64_bits_of_uuid(self) -> None: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + uid = "123e4567-e89b-12d3-a456-426614174000" + expected = uuid.UUID(uid).int & ((1 << 64) - 1) + assert compute_deterministic_span_id(uid) == expected + + def test_non_zero_result_returned_unchanged(self) -> None: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + # This UUID has non-zero lower 64 bits + uid = "123e4567-e89b-12d3-a456-426614174000" + result = compute_deterministic_span_id(uid) + assert result != 0 + + def test_zero_lower_bits_returns_one(self) -> None: + """When the lower 64 bits of the UUID int are 0, the function must return 1 (OTEL requirement).""" + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + # Craft a UUID whose lower 64 bits are 0: upper 64 bits are 1, lower 64 bits are 0. + # int = (1 << 64), UUID fields constructed to produce this integer. 
+ target_int = 1 << 64 # lower 64 bits are 0x0000000000000000 + crafted_uuid = uuid.UUID(int=target_int) + result = compute_deterministic_span_id(str(crafted_uuid)) + assert result == 1 + + def test_raises_on_invalid_uuid(self) -> None: + import pytest + + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + with pytest.raises((ValueError, AttributeError)): + compute_deterministic_span_id("not-a-uuid") + + def test_different_uuids_produce_different_span_ids(self) -> None: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + uid1 = "123e4567-e89b-12d3-a456-426614174000" + uid2 = "987fbc97-4bed-5078-9f07-9141ba07c9f3" + assert compute_deterministic_span_id(uid1) != compute_deterministic_span_id(uid2) + + def test_deterministic_same_input_same_output(self) -> None: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + uid = "123e4567-e89b-12d3-a456-426614174000" + assert compute_deterministic_span_id(uid) == compute_deterministic_span_id(uid) + + +# --------------------------------------------------------------------------- +# Context variable helpers +# --------------------------------------------------------------------------- + + +class TestContextVariableHelpers: + def test_set_and_get_correlation_id(self) -> None: + from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id + + set_correlation_id("corr-123") + assert get_correlation_id() == "corr-123" + + def test_clear_correlation_id(self) -> None: + from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id + + set_correlation_id("corr-abc") + set_correlation_id(None) + assert get_correlation_id() is None + + def test_correlation_id_default_is_none(self) -> None: + from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id + + set_correlation_id(None) + assert get_correlation_id() is None + + def test_set_span_id_source_stored_in_context(self) -> None: + from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source + + set_span_id_source("span-src-1") + assert _span_id_source_context.get() == "span-src-1" + + def test_clear_span_id_source(self) -> None: + from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source + + set_span_id_source("span-src-1") + set_span_id_source(None) + assert _span_id_source_context.get() is None + + +# --------------------------------------------------------------------------- +# CorrelationIdGenerator.generate_trace_id +# --------------------------------------------------------------------------- + + +class TestCorrelationIdGeneratorGenerateTraceId: + def setup_method(self) -> None: + from enterprise.telemetry.id_generator import set_correlation_id + + set_correlation_id(None) + + def test_returns_uuid_int_when_correlation_id_set(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id + + uid = "123e4567-e89b-12d3-a456-426614174000" + set_correlation_id(uid) + gen = CorrelationIdGenerator() + trace_id = gen.generate_trace_id() + assert trace_id == uuid.UUID(uid).int + + def test_returns_random_when_no_correlation_id(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id + + set_correlation_id(None) + gen = CorrelationIdGenerator() + # Should return a non-zero int without raising + trace_id = gen.generate_trace_id() + assert isinstance(trace_id, int) + assert trace_id > 0 + + def 
test_returns_random_when_correlation_id_is_invalid_uuid(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id + + set_correlation_id("not-a-valid-uuid") + gen = CorrelationIdGenerator() + with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=42) as mock_rng: + trace_id = gen.generate_trace_id() + mock_rng.assert_called_once_with(128) + assert trace_id == 42 + + +# --------------------------------------------------------------------------- +# CorrelationIdGenerator.generate_span_id +# --------------------------------------------------------------------------- + + +class TestCorrelationIdGeneratorGenerateSpanId: + def setup_method(self) -> None: + from enterprise.telemetry.id_generator import set_span_id_source + + set_span_id_source(None) + + def test_uses_deterministic_span_id_when_source_set(self) -> None: + from enterprise.telemetry.id_generator import ( + CorrelationIdGenerator, + compute_deterministic_span_id, + set_span_id_source, + ) + + uid = "123e4567-e89b-12d3-a456-426614174000" + set_span_id_source(uid) + gen = CorrelationIdGenerator() + span_id = gen.generate_span_id() + assert span_id == compute_deterministic_span_id(uid) + + def test_returns_random_when_no_source(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source + + set_span_id_source(None) + gen = CorrelationIdGenerator() + span_id = gen.generate_span_id() + assert isinstance(span_id, int) + assert span_id != 0 + + def test_returns_random_when_source_is_invalid_uuid(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source + + set_span_id_source("not-a-uuid") + gen = CorrelationIdGenerator() + with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=7) as mock_rng: + span_id = gen.generate_span_id() + assert span_id == 7 + + def test_random_span_id_retried_if_zero(self) -> None: + """generate_span_id must never return 0 — it retries until non-zero.""" + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source + + set_span_id_source(None) + gen = CorrelationIdGenerator() + # First call returns 0 (invalid), second returns 99 + with patch("enterprise.telemetry.id_generator.random.getrandbits", side_effect=[0, 99]): + span_id = gen.generate_span_id() + assert span_id == 99 + + def test_generate_span_id_always_non_zero(self) -> None: + from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source + + set_span_id_source(None) + gen = CorrelationIdGenerator() + for _ in range(20): + assert gen.generate_span_id() != 0 diff --git a/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py b/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py new file mode 100644 index 0000000000..56c42a57d5 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py @@ -0,0 +1,511 @@ +"""Unit tests for EnterpriseMetricHandler.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + + +@pytest.fixture +def mock_redis(): + with patch("enterprise.telemetry.metric_handler.redis_client") as mock: + yield mock + + +@pytest.fixture +def sample_envelope(): + return TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + 
payload={"app_id": "app-123", "name": "Test App"}, + ) + + +def test_dispatch_app_created(sample_envelope, mock_redis): + mock_redis.set.return_value = True + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_created") as mock_handler: + handler.handle(sample_envelope) + mock_handler.assert_called_once_with(sample_envelope) + + +def test_dispatch_app_updated(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id="test-tenant", + event_id="test-event-456", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_updated") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_app_deleted(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id="test-tenant", + event_id="test-event-789", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_deleted") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_feedback_created(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="test-tenant", + event_id="test-event-abc", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_feedback_created") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_message_run(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.MESSAGE_RUN, + tenant_id="test-tenant", + event_id="test-event-msg", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_message_run") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_tool_execution(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.TOOL_EXECUTION, + tenant_id="test-tenant", + event_id="test-event-tool", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_tool_execution") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_moderation_check(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.MODERATION_CHECK, + tenant_id="test-tenant", + event_id="test-event-mod", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_moderation_check") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_suggested_question(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.SUGGESTED_QUESTION, + tenant_id="test-tenant", + event_id="test-event-sq", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_suggested_question") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_dataset_retrieval(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.DATASET_RETRIEVAL, + tenant_id="test-tenant", + event_id="test-event-ds", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with 
patch.object(handler, "_on_dataset_retrieval") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_generate_name(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.GENERATE_NAME, + tenant_id="test-tenant", + event_id="test-event-gn", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_generate_name") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_prompt_generation(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.PROMPT_GENERATION, + tenant_id="test-tenant", + event_id="test-event-pg", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_prompt_generation") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_all_known_cases_have_handlers(mock_redis): + mock_redis.set.return_value = True + handler = EnterpriseMetricHandler() + + for case in TelemetryCase: + envelope = TelemetryEnvelope( + case=case, + tenant_id="test-tenant", + event_id=f"test-{case.value}", + payload={}, + ) + handler.handle(envelope) + + +def test_idempotency_duplicate(sample_envelope, mock_redis): + mock_redis.set.return_value = None + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_created") as mock_handler: + handler.handle(sample_envelope) + mock_handler.assert_not_called() + + +def test_idempotency_first_seen(sample_envelope, mock_redis): + mock_redis.set.return_value = True + + handler = EnterpriseMetricHandler() + is_dup = handler._is_duplicate(sample_envelope) + + assert is_dup is False + mock_redis.set.assert_called_once_with( + "telemetry:dedup:test-tenant:test-event-123", + b"1", + nx=True, + ex=3600, + ) + + +def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog): + mock_redis.set.side_effect = Exception("Redis unavailable") + + handler = EnterpriseMetricHandler() + is_dup = handler._is_duplicate(sample_envelope) + + assert is_dup is False + assert "Redis unavailable for deduplication check" in caplog.text + + +def test_rehydration_uses_payload(sample_envelope): + handler = EnterpriseMetricHandler() + payload = handler._rehydrate(sample_envelope) + + assert payload == {"app_id": "app-123", "name": "Test App"} + + +def test_rehydration_from_storage(): + """Verify _rehydrate loads payload from object storage via payload_ref.""" + stored_data = {"app_id": "app-stored", "mode": "workflow"} + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-fb", + payload={}, + metadata={"payload_ref": "telemetry/test-tenant/test-event-fb.json"}, + ) + + handler = EnterpriseMetricHandler() + with patch("enterprise.telemetry.metric_handler.storage") as mock_storage: + mock_storage.load.return_value = json.dumps(stored_data).encode("utf-8") + payload = handler._rehydrate(envelope) + + assert payload == stored_data + mock_storage.load.assert_called_once_with("telemetry/test-tenant/test-event-fb.json") + + +def test_rehydration_storage_failure_emits_degraded_event(): + """Verify _rehydrate emits degraded event when storage load fails.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-fail", + payload={}, + metadata={"payload_ref": "telemetry/test-tenant/test-event-fail.json"}, + ) + + 
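+    # A storage outage must not raise: the handler falls back to an empty payload
+    # and reports the failure via a REHYDRATION_FAILED metric-only event.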
handler = EnterpriseMetricHandler() + with ( + patch("enterprise.telemetry.metric_handler.storage") as mock_storage, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_storage.load.side_effect = Exception("Storage unavailable") + payload = handler._rehydrate(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + assert payload == {} + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED + assert "dify.telemetry.error" in call_args[1]["attributes"] + + +def test_rehydration_emits_degraded_event_on_empty_payload(): + """Verify _rehydrate emits degraded event when payload is empty and no ref exists.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-empty", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit: + payload = handler._rehydrate(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + assert payload == {} + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED + assert "dify.telemetry.error" in call_args[1]["attributes"] + + +def test_on_app_created_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789", "mode": "chat"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_created(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_CREATED + assert call_args[1]["tenant_id"] == "tenant-123" + attrs = call_args[1]["attributes"] + assert attrs["dify.app_id"] == "app-789" + assert attrs["dify.tenant_id"] == "tenant-123" + assert attrs["dify.event.id"] == "event-456" + assert attrs["dify.app.mode"] == "chat" + assert "dify.app.created_at" in attrs + + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + counter_call = mock_exporter.increment_counter.call_args + assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_CREATED + assert counter_call[0][1] == 1 + assert counter_call[0][2]["tenant_id"] == "tenant-123" + assert counter_call[0][2]["app_id"] == "app-789" + assert counter_call[0][2]["mode"] == "chat" + + +def test_on_app_updated_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + 
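+        # Invoke the handler directly; the emitted event and the counter
+        # increment are asserted against the mocked exporter below.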
handler._on_app_updated(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_UPDATED + assert call_args[1]["tenant_id"] == "tenant-123" + attrs = call_args[1]["attributes"] + assert attrs["dify.app_id"] == "app-789" + assert attrs["dify.tenant_id"] == "tenant-123" + assert attrs["dify.event.id"] == "event-456" + assert "dify.app.updated_at" in attrs + + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + counter_call = mock_exporter.increment_counter.call_args + assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_UPDATED + assert counter_call[0][1] == 1 + assert counter_call[0][2]["tenant_id"] == "tenant-123" + assert counter_call[0][2]["app_id"] == "app-789" + + +def test_on_app_deleted_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_deleted(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_DELETED + assert call_args[1]["tenant_id"] == "tenant-123" + attrs = call_args[1]["attributes"] + assert attrs["dify.app_id"] == "app-789" + assert attrs["dify.tenant_id"] == "tenant-123" + assert attrs["dify.event.id"] == "event-456" + assert "dify.app.deleted_at" in attrs + + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + counter_call = mock_exporter.increment_counter.call_args + assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_DELETED + assert counter_call[0][1] == 1 + assert counter_call[0][2]["tenant_id"] == "tenant-123" + assert counter_call[0][2]["app_id"] == "app-789" + + +def test_on_feedback_created_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={ + "message_id": "msg-001", + "app_id": "app-789", + "conversation_id": "conv-123", + "from_end_user_id": "user-456", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great!", + }, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_exporter.include_content = True + mock_get_exporter.return_value = mock_exporter + + handler._on_feedback_created(envelope) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == "dify.feedback.created" + assert call_args[1]["attributes"]["dify.message.id"] == "msg-001" + assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!" 
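+        # Content is present only because include_content=True on the exporter.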
+ assert "dify.feedback.created_at" in call_args[1]["attributes"] + assert call_args[1]["tenant_id"] == "tenant-123" + assert call_args[1]["user_id"] == "user-456" + + mock_exporter.increment_counter.assert_called_once() + counter_args = mock_exporter.increment_counter.call_args + assert counter_args[0][2]["app_id"] == "app-789" + assert counter_args[0][2]["rating"] == "like" + + +def test_on_feedback_created_without_content(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={ + "message_id": "msg-001", + "app_id": "app-789", + "conversation_id": "conv-123", + "from_end_user_id": "user-456", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great!", + }, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_exporter.include_content = False + mock_get_exporter.return_value = mock_exporter + + handler._on_feedback_created(envelope) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert "dify.feedback.content" not in call_args[1]["attributes"] diff --git a/api/tests/unit_tests/enterprise/telemetry/test_telemetry_log.py b/api/tests/unit_tests/enterprise/telemetry/test_telemetry_log.py new file mode 100644 index 0000000000..0edd0ace27 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_telemetry_log.py @@ -0,0 +1,327 @@ +"""Unit tests for enterprise/telemetry/telemetry_log.py.""" + +from __future__ import annotations + +import uuid +from unittest.mock import MagicMock, patch + +# --------------------------------------------------------------------------- +# compute_trace_id_hex +# --------------------------------------------------------------------------- + + +class TestComputeTraceIdHex: + def setup_method(self) -> None: + # Clear lru_cache between tests to avoid cross-test pollution + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + compute_trace_id_hex.cache_clear() + + def test_none_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + assert compute_trace_id_hex(None) == "" + + def test_empty_string_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + assert compute_trace_id_hex("") == "" + + def test_already_32_hex_chars_returned_as_is(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + hex_id = "a" * 32 + assert compute_trace_id_hex(hex_id) == hex_id + + def test_valid_uuid_string_converted_to_32_hex(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + uid = "123e4567-e89b-12d3-a456-426614174000" + result = compute_trace_id_hex(uid) + assert len(result) == 32 + assert all(ch in "0123456789abcdef" for ch in result) + # Round-trip: int of the UUID should equal the int parsed from result + assert int(result, 16) == uuid.UUID(uid).int + + def test_invalid_string_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + assert compute_trace_id_hex("not-a-uuid") == "" + + def test_whitespace_stripped(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + uid = " 123e4567-e89b-12d3-a456-426614174000 " + result = 
compute_trace_id_hex(uid) + assert len(result) == 32 + + def test_uppercase_uuid_accepted(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + uid = "123E4567-E89B-12D3-A456-426614174000" + result = compute_trace_id_hex(uid) + assert len(result) == 32 + + def test_result_is_cached(self) -> None: + from enterprise.telemetry.telemetry_log import compute_trace_id_hex + + uid = "123e4567-e89b-12d3-a456-426614174000" + r1 = compute_trace_id_hex(uid) + r2 = compute_trace_id_hex(uid) + assert r1 == r2 + info = compute_trace_id_hex.cache_info() + assert info.hits >= 1 + + +# --------------------------------------------------------------------------- +# compute_span_id_hex +# --------------------------------------------------------------------------- + + +class TestComputeSpanIdHex: + def setup_method(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + compute_span_id_hex.cache_clear() + + def test_none_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + assert compute_span_id_hex(None) == "" + + def test_empty_string_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + assert compute_span_id_hex("") == "" + + def test_already_16_hex_chars_returned_as_is(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + hex_id = "abcdef0123456789" + assert compute_span_id_hex(hex_id) == hex_id + + def test_valid_uuid_produces_16_hex_span_id(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + uid = "123e4567-e89b-12d3-a456-426614174000" + result = compute_span_id_hex(uid) + assert len(result) == 16 + assert all(ch in "0123456789abcdef" for ch in result) + + def test_invalid_string_returns_empty(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + assert compute_span_id_hex("not-a-uuid-at-all!") == "" + + def test_result_is_cached(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex + + uid = "123e4567-e89b-12d3-a456-426614174000" + compute_span_id_hex(uid) + compute_span_id_hex(uid) + info = compute_span_id_hex.cache_info() + assert info.hits >= 1 + + +# --------------------------------------------------------------------------- +# emit_telemetry_log +# --------------------------------------------------------------------------- + + +class TestEmitTelemetryLog: + def setup_method(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex + + compute_trace_id_hex.cache_clear() + compute_span_id_hex.cache_clear() + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_logs_info_with_event_name_and_signal(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log( + event_name="dify.workflow.run", + attributes={"tenant_id": "t1"}, + signal="metric_only", + ) + + mock_logger.info.assert_called_once() + args, kwargs = mock_logger.info.call_args + assert args[0] == "telemetry.%s" + assert args[1] == "metric_only" + extra = kwargs["extra"] + assert extra["attributes"]["dify.event.name"] == "dify.workflow.run" + assert extra["attributes"]["dify.event.signal"] == "metric_only" + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_no_log_when_info_disabled(self, mock_logger: MagicMock) -> None: + from 
enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = False + + emit_telemetry_log(event_name="dify.workflow.run", attributes={}) + + mock_logger.info.assert_not_called() + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_trace_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + uid = "123e4567-e89b-12d3-a456-426614174000" + + emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source=uid) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "trace_id" in extra + assert len(extra["trace_id"]) == 32 + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_trace_id_absent_when_invalid_source(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source="bad-id") + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "trace_id" not in extra + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_span_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + uid = "123e4567-e89b-12d3-a456-426614174000" + + emit_telemetry_log(event_name="test.event", attributes={}, span_id_source=uid) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "span_id" in extra + assert len(extra["span_id"]) == 16 + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_tenant_id_added_when_provided(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}, tenant_id="tenant-99") + + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["tenant_id"] == "tenant-99" + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_user_id_added_when_provided(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}, user_id="user-42") + + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["user_id"] == "user-42" + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_tenant_and_user_id_absent_when_not_provided(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "tenant_id" not in extra + assert "user_id" not in extra + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_caller_attributes_merged_into_attrs(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log( + event_name="dify.node.run", + attributes={"node_type": "code", "elapsed": 0.5}, + ) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["attributes"]["node_type"] == "code" + assert extra["attributes"]["elapsed"] == 0.5 + + 
@patch("enterprise.telemetry.telemetry_log.logger") + def test_signal_span_detail_forwarded(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_telemetry_log + + mock_logger.isEnabledFor.return_value = True + + emit_telemetry_log(event_name="test.event", attributes={}, signal="span_detail") + + args = mock_logger.info.call_args[0] + assert args[1] == "span_detail" + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["attributes"]["dify.event.signal"] == "span_detail" + + +# --------------------------------------------------------------------------- +# emit_metric_only_event +# --------------------------------------------------------------------------- + + +class TestEmitMetricOnlyEvent: + def setup_method(self) -> None: + from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex + + compute_trace_id_hex.cache_clear() + compute_span_id_hex.cache_clear() + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_delegates_to_emit_telemetry_log_with_metric_only_signal(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + mock_logger.isEnabledFor.return_value = True + + emit_metric_only_event( + event_name="dify.app.created", + attributes={"app_id": "app-1"}, + tenant_id="t1", + user_id="u1", + ) + + mock_logger.info.assert_called_once() + extra = mock_logger.info.call_args.kwargs["extra"] + assert extra["attributes"]["dify.event.signal"] == "metric_only" + assert extra["attributes"]["dify.event.name"] == "dify.app.created" + assert extra["attributes"]["app_id"] == "app-1" + assert extra["tenant_id"] == "t1" + assert extra["user_id"] == "u1" + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_trace_and_span_ids_passed_through(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + mock_logger.isEnabledFor.return_value = True + uid = "123e4567-e89b-12d3-a456-426614174000" + + emit_metric_only_event( + event_name="dify.workflow.run", + attributes={}, + trace_id_source=uid, + span_id_source=uid, + ) + + extra = mock_logger.info.call_args.kwargs["extra"] + assert "trace_id" in extra + assert "span_id" in extra + + @patch("enterprise.telemetry.telemetry_log.logger") + def test_no_log_emitted_when_logger_disabled(self, mock_logger: MagicMock) -> None: + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + mock_logger.isEnabledFor.return_value = False + + emit_metric_only_event(event_name="dify.workflow.run", attributes={}) + + mock_logger.info.assert_not_called() diff --git a/api/tests/unit_tests/events/test_app_event_signals.py b/api/tests/unit_tests/events/test_app_event_signals.py new file mode 100644 index 0000000000..29582a50f6 --- /dev/null +++ b/api/tests/unit_tests/events/test_app_event_signals.py @@ -0,0 +1,206 @@ +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_db(): + with patch("services.app_service.db") as mock_db: + mock_db.session = MagicMock() + yield mock_db + + +@pytest.fixture +def _mock_deps(): + with ( + patch("services.app_service.BillingService"), + patch("services.app_service.FeatureService"), + patch("services.app_service.EnterpriseService"), + patch("services.app_service.remove_app_and_related_data_task"), + ): + yield + + +@pytest.fixture +def app_model(): + app = MagicMock() + app.id = "app-123" + app.tenant_id = "tenant-456" + app.name = "Old Name" + app.icon_type = "emoji" + 
app.icon = "🤖" + app.icon_background = "#fff" + app.enable_site = False + app.enable_api = False + return app + + +def _make_collector(target: list): + def handler(sender, **kw): + target.append(sender) + + return handler + + +@pytest.mark.usefixtures("mock_db", "_mock_deps") +class TestAppWasDeletedSignal: + def test_sends_signal(self, app_model): + from events.app_event import app_was_deleted + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_deleted.connect(handler) + try: + AppService().delete_app(app_model) + finally: + app_was_deleted.disconnect(handler) + + assert received == [app_model] + + def test_signal_fires_before_db_delete(self, app_model, mock_db): + from events.app_event import app_was_deleted + from services.app_service import AppService + + call_order: list[str] = [] + + def handler(sender, **kw): + call_order.append("signal") + + app_was_deleted.connect(handler) + mock_db.session.delete.side_effect = lambda _: call_order.append("db_delete") + + try: + AppService().delete_app(app_model) + finally: + app_was_deleted.disconnect(handler) + + assert call_order.index("signal") < call_order.index("db_delete") + + +@pytest.mark.usefixtures("mock_db") +class TestAppWasUpdatedSignal: + def test_update_app(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + AppService().update_app( + app_model, + { + "name": "New", + "description": "Desc", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#fff", + "use_icon_as_answer_icon": False, + "max_active_requests": 0, + }, + ) + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_name(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + AppService().update_app_name(app_model, "New Name") + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_icon(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + AppService().update_app_icon(app_model, "🎉", "#000") + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_site_status_sends_when_changed(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + app_model.enable_site = False + AppService().update_app_site_status(app_model, True) + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_site_status_skips_when_unchanged(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + 
app_was_updated.connect(handler) + + try: + app_model.enable_site = True + AppService().update_app_site_status(app_model, True) + finally: + app_was_updated.disconnect(handler) + + assert received == [] + + def test_update_app_api_status_sends_when_changed(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + with patch("services.app_service.current_user", MagicMock(id="user-1")): + try: + app_model.enable_api = False + AppService().update_app_api_status(app_model, True) + finally: + app_was_updated.disconnect(handler) + + assert received == [app_model] + + def test_update_app_api_status_skips_when_unchanged(self, app_model): + from events.app_event import app_was_updated + from services.app_service import AppService + + received = [] + handler = _make_collector(received) + app_was_updated.connect(handler) + + try: + app_model.enable_api = True + AppService().update_app_api_status(app_model, True) + finally: + app_was_updated.disconnect(handler) + + assert received == [] diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 511192001e..4fe3f2cb28 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -2,13 +2,13 @@ import uuid from unittest.mock import MagicMock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from httpx import Response from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope from core.workflow.file_reference import build_file_reference, parse_file_reference, resolve_file_record_id from factories.file_factory.builders import build_from_mapping as _build_from_mapping -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from models import ToolFile, UploadFile # Test Data diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 70d7d8c575..8d573b1154 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,11 +4,6 @@ from typing import Any from uuid import uuid4 import pytest -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from factories import variable_factory -from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type from graphon.file import File, FileTransferMethod, FileType from graphon.variables import ( ArrayNumberVariable, @@ -36,6 +31,11 @@ from graphon.variables.segments import ( StringSegment, ) from graphon.variables.types import SegmentType +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +from factories import variable_factory +from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type def test_string_variable(): diff --git a/api/tests/unit_tests/fields/test_file_fields.py b/api/tests/unit_tests/fields/test_file_fields.py index 9d9f626b9e..0e848d6ef5 100644 --- a/api/tests/unit_tests/fields/test_file_fields.py +++ b/api/tests/unit_tests/fields/test_file_fields.py @@ -4,11 +4,11 @@ from datetime import datetime from types import 
SimpleNamespace import pytest +from graphon.file import File, FileTransferMethod, FileType from core.workflow.file_reference import build_file_reference from fields import conversation_fields, message_fields from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig -from graphon.file import File, FileTransferMethod, FileType def test_file_response_serializes_datetime() -> None: diff --git a/api/tests/unit_tests/graphon/file/test_file_factory.py b/api/tests/unit_tests/graphon/file/test_file_factory.py deleted file mode 100644 index eeb537c28f..0000000000 --- a/api/tests/unit_tests/graphon/file/test_file_factory.py +++ /dev/null @@ -1,18 +0,0 @@ -from graphon.file import FileType -from graphon.file.file_factory import get_file_type_by_mime_type, standardize_file_type - - -def test_standardize_file_type_recognizes_case_insensitive_extension(): - assert standardize_file_type(extension=".PNG") == FileType.IMAGE - - -def test_standardize_file_type_recognizes_document_extension(): - assert standardize_file_type(extension=".txt") == FileType.DOCUMENT - - -def test_standardize_file_type_falls_back_to_mime_type(): - assert standardize_file_type(mime_type="video/mp4") == FileType.VIDEO - - -def test_get_file_type_by_mime_type_returns_custom_for_unknown_type(): - assert get_file_type_by_mime_type("application/octet-stream") == FileType.CUSTOM diff --git a/api/tests/unit_tests/graphon/file/test_file_manager.py b/api/tests/unit_tests/graphon/file/test_file_manager.py deleted file mode 100644 index 1eebb13f4e..0000000000 --- a/api/tests/unit_tests/graphon/file/test_file_manager.py +++ /dev/null @@ -1,133 +0,0 @@ -import base64 -from unittest.mock import MagicMock - -import pytest - -from core.workflow.file_reference import build_file_reference -from graphon.file import File, FileTransferMethod, FileType -from graphon.file.file_manager import download, to_prompt_message_content -from graphon.file.runtime import get_workflow_file_runtime, set_workflow_file_runtime -from graphon.model_runtime.entities import ( - DocumentPromptMessageContent, - ImagePromptMessageContent, - TextPromptMessageContent, -) - - -def _build_file( - *, - transfer_method: FileTransferMethod, - file_type: FileType = FileType.IMAGE, - reference: str | None = None, - remote_url: str | None = None, - filename: str = "image.png", - extension: str = ".png", - mime_type: str = "image/png", -) -> File: - return File( - id="file-id", - type=file_type, - transfer_method=transfer_method, - reference=reference, - remote_url=remote_url, - filename=filename, - extension=extension, - mime_type=mime_type, - size=128, - ) - - -@pytest.fixture -def workflow_file_runtime(): - previous_runtime = get_workflow_file_runtime() - runtime = MagicMock() - set_workflow_file_runtime(runtime) - try: - yield runtime - finally: - set_workflow_file_runtime(previous_runtime) - - -@pytest.mark.parametrize( - "transfer_method", - [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.TOOL_FILE, - FileTransferMethod.DATASOURCE_FILE, - ], -) -def test_download_delegates_storage_backed_files_to_runtime_loader(workflow_file_runtime, transfer_method) -> None: - workflow_file_runtime.load_file_bytes.return_value = b"payload" - file = _build_file( - transfer_method=transfer_method, - reference=build_file_reference(record_id="file-id", storage_key="files/payload.bin"), - ) - - assert download(file) == b"payload" - workflow_file_runtime.load_file_bytes.assert_called_once_with(file=file) - - -def 
test_download_remote_url_uses_runtime_http_get(workflow_file_runtime) -> None: - response = MagicMock() - response.content = b"remote-payload" - workflow_file_runtime.http_get.return_value = response - file = _build_file( - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url="https://example.com/image.png", - ) - - assert download(file) == b"remote-payload" - workflow_file_runtime.http_get.assert_called_once_with("https://example.com/image.png", follow_redirects=True) - response.raise_for_status.assert_called_once_with() - - -def test_to_prompt_message_content_uses_runtime_url_resolution_for_images(workflow_file_runtime) -> None: - workflow_file_runtime.multimodal_send_format = "url" - workflow_file_runtime.resolve_file_url.return_value = "https://cdn.example.com/image.png" - file = _build_file( - transfer_method=FileTransferMethod.LOCAL_FILE, - reference=build_file_reference(record_id="upload-file-id", storage_key="files/image.png"), - ) - - content = to_prompt_message_content(file, image_detail_config=ImagePromptMessageContent.DETAIL.HIGH) - - assert isinstance(content, ImagePromptMessageContent) - assert content.url == "https://cdn.example.com/image.png" - assert content.base64_data == "" - assert content.detail == ImagePromptMessageContent.DETAIL.HIGH - - -def test_to_prompt_message_content_uses_runtime_file_loader_for_base64_documents(workflow_file_runtime) -> None: - workflow_file_runtime.multimodal_send_format = "base64" - workflow_file_runtime.load_file_bytes.return_value = b"document-bytes" - file = _build_file( - transfer_method=FileTransferMethod.TOOL_FILE, - file_type=FileType.DOCUMENT, - reference=build_file_reference(record_id="tool-file-id", storage_key="docs/report.pdf"), - filename="report.pdf", - extension=".pdf", - mime_type="application/pdf", - ) - - content = to_prompt_message_content(file) - - assert isinstance(content, DocumentPromptMessageContent) - assert content.base64_data == base64.b64encode(b"document-bytes").decode("utf-8") - assert content.url == "" - workflow_file_runtime.load_file_bytes.assert_called_once_with(file=file) - - -def test_to_prompt_message_content_returns_text_placeholder_for_custom_files() -> None: - file = _build_file( - transfer_method=FileTransferMethod.REMOTE_URL, - file_type=FileType.CUSTOM, - remote_url="https://example.com/archive.bin", - filename="archive.bin", - extension=".bin", - mime_type="application/octet-stream", - ) - - content = to_prompt_message_content(file) - - assert isinstance(content, TextPromptMessageContent) - assert content.data == "[Unsupported file type: archive.bin (custom)]" diff --git a/api/tests/unit_tests/graphon/file/test_models.py b/api/tests/unit_tests/graphon/file/test_models.py deleted file mode 100644 index 17d244da5f..0000000000 --- a/api/tests/unit_tests/graphon/file/test_models.py +++ /dev/null @@ -1,54 +0,0 @@ -from core.workflow.file_reference import build_file_reference -from graphon.file import File, FileTransferMethod, FileType, helpers - - -def _build_local_file(*, reference: str, storage_key: str | None = None) -> File: - return File( - id="file-id", - type=FileType.DOCUMENT, - transfer_method=FileTransferMethod.LOCAL_FILE, - reference=reference, - filename="report.pdf", - extension=".pdf", - mime_type="application/pdf", - size=128, - storage_key=storage_key, - ) - - -def test_file_exposes_legacy_aliases_from_opaque_reference() -> None: - reference = build_file_reference(record_id="upload-file-id", storage_key="files/report.pdf") - - file = _build_local_file(reference=reference) - - 
assert file.reference == reference - assert file.related_id == "upload-file-id" - assert file.storage_key == "files/report.pdf" - - -def test_file_falls_back_to_raw_reference_when_opaque_reference_is_invalid() -> None: - file = _build_local_file(reference="dify-file-ref:not-base64", storage_key="fallback-key") - - assert file.related_id == "dify-file-ref:not-base64" - assert file.storage_key == "fallback-key" - - -def test_file_to_dict_keeps_reference_and_legacy_related_id(monkeypatch) -> None: - reference = build_file_reference(record_id="upload-file-id", storage_key="files/report.pdf") - file = _build_local_file(reference=reference) - monkeypatch.setattr(helpers, "resolve_file_url", lambda _file, for_external=True: "https://example.com/report.pdf") - - serialized = file.to_dict() - - assert serialized["reference"] == reference - assert serialized["related_id"] == "upload-file-id" - assert serialized["url"] == "https://example.com/report.pdf" - - -def test_file_related_id_setter_updates_reference_alias() -> None: - file = _build_local_file(reference="upload-file-id", storage_key="files/report.pdf") - - file.related_id = "replacement-upload-id" - - assert file.reference == "replacement-upload-id" - assert file.related_id == "replacement-upload-id" diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/__init__.py b/api/tests/unit_tests/graphon/model_runtime/__base/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py deleted file mode 100644 index 7b4fc5a04c..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/__base/test_increase_tool_call.py +++ /dev/null @@ -1,114 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage -from graphon.model_runtime.model_providers.__base.large_language_model import _increase_tool_call - -ToolCall = AssistantPromptMessage.ToolCall - -# CASE 1: Single tool call -INPUTS_CASE_1 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_1 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), -] - -# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...) 
-INPUTS_CASE_2 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_2 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), - ToolCall( - id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') - ), -] - -# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...) -INPUTS_CASE_3 = [ - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_3 = [ - ToolCall( - id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') - ), - ToolCall( - id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') - ), -] - -# CASE 4: Tool call sequences with no IDs -INPUTS_CASE_4 = [ - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), -] -EXPECTED_CASE_4 = [ - ToolCall( - id="RANDOM_ID_1", - type="function", - function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), - ), - ToolCall( - id="RANDOM_ID_2", - type="function", - function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'), - ), -] - - -def _run_case(inputs: list[ToolCall], expected: list[ToolCall]): - actual = [] - _increase_tool_call(inputs, actual) - assert actual == expected - - -def test__increase_tool_call(): - # case 1: - _run_case(INPUTS_CASE_1, EXPECTED_CASE_1) - - # case 2: - _run_case(INPUTS_CASE_2, EXPECTED_CASE_2) - - # case 3: - _run_case(INPUTS_CASE_3, EXPECTED_CASE_3) - - # case 4: - mock_id_generator = MagicMock() - mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] - with patch( - "graphon.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator - ): - 
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4) - - -def test__increase_tool_call__no_id_no_name_first_delta_should_raise(): - inputs = [ - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), - ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')), - ] - actual: list[ToolCall] = [] - with patch("graphon.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): - with pytest.raises(ValueError): - _increase_tool_call(inputs, actual) diff --git a/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py deleted file mode 100644 index c922fbaa60..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ /dev/null @@ -1,126 +0,0 @@ -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_runtime_result - - -def _make_chunk( - *, - model: str = "test-model", - content: str | list[TextPromptMessageContent] | None, - tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, - usage: LLMUsage | None = None, - system_fingerprint: str | None = None, -) -> LLMResultChunk: - message = AssistantPromptMessage(content=content, tool_calls=tool_calls or []) - delta = LLMResultChunkDelta(index=0, message=message, usage=usage) - return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint) - - -def test__normalize_non_stream_runtime_result__from_first_chunk_str_content_and_tool_calls(): - prompt_messages = [UserPromptMessage(content="hi")] - - tool_calls = [ - AssistantPromptMessage.ToolCall( - id="1", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments=""), - ), - AssistantPromptMessage.ToolCall( - id="", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='{"arg1": '), - ), - AssistantPromptMessage.ToolCall( - id="", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'), - ), - ] - - usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1}) - chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1") - - result = _normalize_non_stream_runtime_result( - model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) - ) - - assert result.model == "test-model" - assert result.prompt_messages == prompt_messages - assert result.message.content == "hello" - assert result.usage.prompt_tokens == 1 - assert result.system_fingerprint == "fp-1" - assert result.message.tool_calls == [ - AssistantPromptMessage.ToolCall( - id="1", - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), - ) - ] - - -def test__normalize_non_stream_runtime_result__from_first_chunk_list_content(): - prompt_messages = [UserPromptMessage(content="hi")] - - content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")] - chunk = _make_chunk(content=content_list, 
usage=LLMUsage.empty_usage())
-
-    result = _normalize_non_stream_runtime_result(
-        model="test-model", prompt_messages=prompt_messages, result=iter([chunk])
-    )
-
-    assert result.message.content == content_list
-
-
-def test__normalize_non_stream_runtime_result__passthrough_llm_result():
-    prompt_messages = [UserPromptMessage(content="hi")]
-    llm_result = LLMResult(
-        model="test-model",
-        prompt_messages=prompt_messages,
-        message=AssistantPromptMessage(content="ok"),
-        usage=LLMUsage.empty_usage(),
-    )
-
-    assert (
-        _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=llm_result)
-        == llm_result
-    )
-
-
-def test__normalize_non_stream_runtime_result__empty_iterator_defaults():
-    prompt_messages = [UserPromptMessage(content="hi")]
-
-    result = _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=iter([]))
-
-    assert result.model == "test-model"
-    assert result.prompt_messages == prompt_messages
-    assert result.message.content == []
-    assert result.message.tool_calls == []
-    assert result.usage == LLMUsage.empty_usage()
-    assert result.system_fingerprint is None
-
-
-def test__normalize_non_stream_runtime_result__accumulates_all_chunks():
-    """All chunks are accumulated from the iterator."""
-    prompt_messages = [UserPromptMessage(content="hi")]
-
-    closed: list[bool] = []
-
-    def _chunk_iter():
-        try:
-            yield _make_chunk(content="hello", usage=LLMUsage.empty_usage())
-            yield _make_chunk(content=" world", usage=LLMUsage.empty_usage())
-        finally:
-            closed.append(True)
-
-    result = _normalize_non_stream_runtime_result(
-        model="test-model",
-        prompt_messages=prompt_messages,
-        result=_chunk_iter(),
-    )
-
-    assert result.message.content == "hello world"
-    assert closed == [True]
diff --git a/api/tests/unit_tests/graphon/model_runtime/__init__.py b/api/tests/unit_tests/graphon/model_runtime/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py b/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py
deleted file mode 100644
index 776fc230cb..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_base_callback.py
+++ /dev/null
@@ -1,964 +0,0 @@
-"""Comprehensive unit tests for graphon/model_runtime/callbacks/base_callback.py"""
-
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from graphon.model_runtime.callbacks.base_callback import (
-    _TEXT_COLOR_MAPPING,
-    Callback,
-)
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-
-# ---------------------------------------------------------------------------
-# Concrete implementation of the abstract Callback for testing
-# ---------------------------------------------------------------------------
-
-
-class ConcreteCallback(Callback):
-    """A minimal concrete subclass that satisfies all abstract methods."""
-
-    def __init__(self, raise_error: bool = False):
-        self.raise_error = raise_error
-        # Track invocations
-        self.before_invoke_calls: list[dict] = []
-        self.new_chunk_calls: list[dict] = []
-        self.after_invoke_calls: list[dict] = []
-        self.invoke_error_calls: list[dict] = []
-
-    def on_before_invoke(
-        self,
-        llm_instance,
-        model,
-        credentials,
-        prompt_messages,
-        model_parameters,
-        tools=None,
-        stop=None,
-        stream=True,
-        user=None,
-    ):
-        self.before_invoke_calls.append(
-            {
"llm_instance": llm_instance, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - # To cover the 'raise NotImplementedError()' in the base class - try: - super().on_before_invoke( - llm_instance, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_new_chunk( - self, - llm_instance, - chunk, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.new_chunk_calls.append( - { - "llm_instance": llm_instance, - "chunk": chunk, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_new_chunk( - llm_instance, chunk, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_after_invoke( - self, - llm_instance, - result, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.after_invoke_calls.append( - { - "llm_instance": llm_instance, - "result": result, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_after_invoke( - llm_instance, result, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - def on_invoke_error( - self, - llm_instance, - ex, - model, - credentials, - prompt_messages, - model_parameters, - tools=None, - stop=None, - stream=True, - user=None, - ): - self.invoke_error_calls.append( - { - "llm_instance": llm_instance, - "ex": ex, - "model": model, - "credentials": credentials, - "prompt_messages": prompt_messages, - "model_parameters": model_parameters, - "tools": tools, - "stop": stop, - "stream": stream, - "user": user, - } - ) - try: - super().on_invoke_error( - llm_instance, ex, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user - ) - except NotImplementedError: - pass - - -# --------------------------------------------------------------------------- -# A subclass that deliberately leaves abstract methods un-implemented, -# used to verify that instantiation raises TypeError. 
-# --------------------------------------------------------------------------- - - -# =========================================================================== -# Tests for _TEXT_COLOR_MAPPING module-level constant -# =========================================================================== - - -class TestTextColorMapping: - """Tests for the module-level _TEXT_COLOR_MAPPING dictionary.""" - - def test_contains_all_expected_colors(self): - expected_keys = {"blue", "yellow", "pink", "green", "red"} - assert set(_TEXT_COLOR_MAPPING.keys()) == expected_keys - - def test_blue_escape_code(self): - assert _TEXT_COLOR_MAPPING["blue"] == "36;1" - - def test_yellow_escape_code(self): - assert _TEXT_COLOR_MAPPING["yellow"] == "33;1" - - def test_pink_escape_code(self): - assert _TEXT_COLOR_MAPPING["pink"] == "38;5;200" - - def test_green_escape_code(self): - assert _TEXT_COLOR_MAPPING["green"] == "32;1" - - def test_red_escape_code(self): - assert _TEXT_COLOR_MAPPING["red"] == "31;1" - - def test_mapping_is_dict(self): - assert isinstance(_TEXT_COLOR_MAPPING, dict) - - def test_all_values_are_strings(self): - for key, value in _TEXT_COLOR_MAPPING.items(): - assert isinstance(value, str), f"Value for {key!r} should be str" - - -# =========================================================================== -# Tests for the Callback ABC itself -# =========================================================================== - - -class TestCallbackAbstract: - """Tests verifying Callback is a proper ABC.""" - - def test_cannot_instantiate_abstract_class_directly(self): - """Callback cannot be instantiated since it has abstract methods.""" - with pytest.raises(TypeError): - Callback() # type: ignore[abstract] - - def test_concrete_subclass_can_be_instantiated(self): - cb = ConcreteCallback() - assert isinstance(cb, Callback) - - def test_default_raise_error_is_false(self): - cb = ConcreteCallback() - assert cb.raise_error is False - - def test_raise_error_can_be_set_to_true(self): - cb = ConcreteCallback(raise_error=True) - assert cb.raise_error is True - - def test_subclass_missing_on_before_invoke_raises_type_error(self): - """A subclass missing any single abstract method cannot be instantiated.""" - - class IncompleteCallback(Callback): - def on_new_chunk(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_new_chunk_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_after_invoke_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_new_chunk(self, *a, **kw): ... - def on_invoke_error(self, *a, **kw): ... - - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - def test_subclass_missing_on_invoke_error_raises_type_error(self): - class IncompleteCallback(Callback): - def on_before_invoke(self, *a, **kw): ... - def on_new_chunk(self, *a, **kw): ... - def on_after_invoke(self, *a, **kw): ... 
- - with pytest.raises(TypeError): - IncompleteCallback() # type: ignore[abstract] - - -# =========================================================================== -# Tests for on_before_invoke -# =========================================================================== - - -class TestOnBeforeInvoke: - """Tests for the on_before_invoke callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.model = "gpt-4" - self.credentials = {"api_key": "sk-test"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"temperature": 0.7} - - def test_on_before_invoke_called_with_required_args(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.before_invoke_calls) == 1 - call = self.cb.before_invoke_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["model"] == self.model - assert call["credentials"] == self.credentials - assert call["prompt_messages"] is self.prompt_messages - assert call["model_parameters"] is self.model_parameters - - def test_on_before_invoke_defaults_tools_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["tools"] is None - - def test_on_before_invoke_defaults_stop_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["stop"] is None - - def test_on_before_invoke_defaults_stream_true(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["stream"] is True - - def test_on_before_invoke_defaults_user_none(self): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.before_invoke_calls[0]["user"] is None - - def test_on_before_invoke_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["stop1", "stop2"] - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="user-123", - ) - call = self.cb.before_invoke_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "user-123" - - def test_on_before_invoke_called_multiple_times(self): - for i in range(3): - self.cb.on_before_invoke( - llm_instance=self.llm_instance, - model=f"model-{i}", - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.before_invoke_calls) == 3 - assert self.cb.before_invoke_calls[2]["model"] == "model-2" - - -# =========================================================================== -# Tests for on_new_chunk -# 
=========================================================================== - - -class TestOnNewChunk: - """Tests for the on_new_chunk callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.chunk = MagicMock(spec=LLMResultChunk) - self.model = "gpt-3.5-turbo" - self.credentials = {"api_key": "sk-test"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"max_tokens": 256} - - def test_on_new_chunk_called_with_required_args(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.new_chunk_calls) == 1 - call = self.cb.new_chunk_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["chunk"] is self.chunk - assert call["model"] == self.model - assert call["credentials"] == self.credentials - - def test_on_new_chunk_defaults_tools_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["tools"] is None - - def test_on_new_chunk_defaults_stop_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["stop"] is None - - def test_on_new_chunk_defaults_stream_true(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["stream"] is True - - def test_on_new_chunk_defaults_user_none(self): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.new_chunk_calls[0]["user"] is None - - def test_on_new_chunk_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["END"] - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="chunk-user", - ) - call = self.cb.new_chunk_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "chunk-user" - - def test_on_new_chunk_called_multiple_times(self): - for i in range(5): - self.cb.on_new_chunk( - llm_instance=self.llm_instance, - chunk=self.chunk, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.new_chunk_calls) == 5 - - -# =========================================================================== -# Tests for on_after_invoke -# =========================================================================== - - -class TestOnAfterInvoke: - """Tests for the on_after_invoke callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = 
MagicMock() - self.result = MagicMock(spec=LLMResult) - self.model = "claude-3" - self.credentials = {"api_key": "anthropic-key"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"temperature": 1.0} - - def test_on_after_invoke_called_with_required_args(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert len(self.cb.after_invoke_calls) == 1 - call = self.cb.after_invoke_calls[0] - assert call["llm_instance"] is self.llm_instance - assert call["result"] is self.result - assert call["model"] == self.model - assert call["credentials"] is self.credentials - - def test_on_after_invoke_defaults_tools_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["tools"] is None - - def test_on_after_invoke_defaults_stop_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["stop"] is None - - def test_on_after_invoke_defaults_stream_true(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["stream"] is True - - def test_on_after_invoke_defaults_user_none(self): - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - assert self.cb.after_invoke_calls[0]["user"] is None - - def test_on_after_invoke_with_all_optional_args(self): - tools = [MagicMock(spec=PromptMessageTool)] - stop = ["STOP"] - self.cb.on_after_invoke( - llm_instance=self.llm_instance, - result=self.result, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - tools=tools, - stop=stop, - stream=False, - user="after-user", - ) - call = self.cb.after_invoke_calls[0] - assert call["tools"] is tools - assert call["stop"] == stop - assert call["stream"] is False - assert call["user"] == "after-user" - - -# =========================================================================== -# Tests for on_invoke_error -# =========================================================================== - - -class TestOnInvokeError: - """Tests for the on_invoke_error callback method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - self.llm_instance = MagicMock() - self.ex = ValueError("something went wrong") - self.model = "gemini-pro" - self.credentials = {"api_key": "google-key"} - self.prompt_messages = [MagicMock(spec=PromptMessage)] - self.model_parameters = {"top_p": 0.9} - - def test_on_invoke_error_called_with_required_args(self): - self.cb.on_invoke_error( - llm_instance=self.llm_instance, - ex=self.ex, - model=self.model, - credentials=self.credentials, - prompt_messages=self.prompt_messages, - model_parameters=self.model_parameters, - ) - 
assert len(self.cb.invoke_error_calls) == 1
-        call = self.cb.invoke_error_calls[0]
-        assert call["llm_instance"] is self.llm_instance
-        assert call["ex"] is self.ex
-        assert call["model"] == self.model
-        assert call["credentials"] is self.credentials
-
-    def test_on_invoke_error_defaults_tools_none(self):
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-        )
-        assert self.cb.invoke_error_calls[0]["tools"] is None
-
-    def test_on_invoke_error_defaults_stop_none(self):
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-        )
-        assert self.cb.invoke_error_calls[0]["stop"] is None
-
-    def test_on_invoke_error_defaults_stream_true(self):
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-        )
-        assert self.cb.invoke_error_calls[0]["stream"] is True
-
-    def test_on_invoke_error_defaults_user_none(self):
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-        )
-        assert self.cb.invoke_error_calls[0]["user"] is None
-
-    def test_on_invoke_error_with_all_optional_args(self):
-        tools = [MagicMock(spec=PromptMessageTool)]
-        stop = ["HALT"]
-        self.cb.on_invoke_error(
-            llm_instance=self.llm_instance,
-            ex=self.ex,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=self.prompt_messages,
-            model_parameters=self.model_parameters,
-            tools=tools,
-            stop=stop,
-            stream=False,
-            user="error-user",
-        )
-        call = self.cb.invoke_error_calls[0]
-        assert call["tools"] is tools
-        assert call["stop"] == stop
-        assert call["stream"] is False
-        assert call["user"] == "error-user"
-
-    def test_on_invoke_error_accepts_various_exception_types(self):
-        for exc in [RuntimeError("r"), KeyError("k"), Exception("e")]:
-            self.cb.on_invoke_error(
-                llm_instance=self.llm_instance,
-                ex=exc,
-                model=self.model,
-                credentials=self.credentials,
-                prompt_messages=self.prompt_messages,
-                model_parameters=self.model_parameters,
-            )
-        assert len(self.cb.invoke_error_calls) == 3
-
-
-# ===========================================================================
-# Tests for print_text (concrete method on Callback)
-# ===========================================================================
-
-
-class TestPrintText:
-    """Tests for the concrete print_text method."""
-
-    def setup_method(self):
-        self.cb = ConcreteCallback()
-
-    def test_print_text_without_color_prints_plain_text(self, capsys):
-        self.cb.print_text("hello world")
-        captured = capsys.readouterr()
-        assert captured.out == "hello world"
-
-    def test_print_text_with_color_prints_colored_text(self, capsys):
-        self.cb.print_text("colored text", color="blue")
-        captured = capsys.readouterr()
-        # Should contain ANSI escape sequences
-        assert "colored text" in captured.out
-        assert "\u001b[" in captured.out or "\033[" in captured.out or "\x1b[" in captured.out
-
-    def test_print_text_without_color_no_ansi(self, capsys):
-        self.cb.print_text("plain text", color=None)
-        captured = capsys.readouterr()
-        assert captured.out == "plain text"
-        # No ANSI
escape sequences - assert "\x1b" not in captured.out - - def test_print_text_default_end_is_empty_string(self, capsys): - self.cb.print_text("no newline") - captured = capsys.readouterr() - assert not captured.out.endswith("\n") - - def test_print_text_with_custom_end(self, capsys): - self.cb.print_text("with newline", end="\n") - captured = capsys.readouterr() - assert captured.out.endswith("\n") - - def test_print_text_with_empty_string(self, capsys): - self.cb.print_text("", color=None) - captured = capsys.readouterr() - assert captured.out == "" - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_print_text_all_colors_work(self, color, capsys): - """Verify no KeyError is thrown for any valid color.""" - self.cb.print_text("test", color=color) - captured = capsys.readouterr() - assert "test" in captured.out - - def test_print_text_calls_get_colored_text_when_color_given(self): - with patch.object(self.cb, "_get_colored_text", return_value="[COLORED]") as mock_gct: - with patch("builtins.print") as mock_print: - self.cb.print_text("hello", color="green") - mock_gct.assert_called_once_with("hello", "green") - mock_print.assert_called_once_with("[COLORED]", end="") - - def test_print_text_does_not_call_get_colored_text_when_no_color(self): - with patch.object(self.cb, "_get_colored_text") as mock_gct: - with patch("builtins.print"): - self.cb.print_text("hello", color=None) - mock_gct.assert_not_called() - - def test_print_text_passes_end_to_print(self): - with patch("builtins.print") as mock_print: - self.cb.print_text("text", end="---") - mock_print.assert_called_once_with("text", end="---") - - -# =========================================================================== -# Tests for _get_colored_text (private helper method) -# =========================================================================== - - -class TestGetColoredText: - """Tests for the _get_colored_text private method.""" - - def setup_method(self): - self.cb = ConcreteCallback() - - @pytest.mark.parametrize(("color", "expected_code"), list(_TEXT_COLOR_MAPPING.items())) - def test_get_colored_text_uses_correct_escape_code(self, color, expected_code): - result = self.cb._get_colored_text("text", color) - assert expected_code in result - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_contains_input_text(self, color): - result = self.cb._get_colored_text("hello", color) - assert "hello" in result - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_starts_with_escape(self, color): - result = self.cb._get_colored_text("text", color) - # Should start with an ANSI escape (\x1b or \u001b) - assert result.startswith("\x1b[") or result.startswith("\u001b[") - - @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) - def test_get_colored_text_ends_with_reset(self, color): - result = self.cb._get_colored_text("text", color) - # Should end with the ANSI reset code - assert result.endswith("\x1b[0m") or result.endswith("\u001b[0m") - - def test_get_colored_text_returns_string(self): - result = self.cb._get_colored_text("text", "blue") - assert isinstance(result, str) - - def test_get_colored_text_blue_exact_format(self): - result = self.cb._get_colored_text("hello", "blue") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['blue']}m\033[1;3mhello\u001b[0m" - assert result == expected - - def test_get_colored_text_red_exact_format(self): - result = 
self.cb._get_colored_text("error", "red") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['red']}m\033[1;3merror\u001b[0m" - assert result == expected - - def test_get_colored_text_green_exact_format(self): - result = self.cb._get_colored_text("ok", "green") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['green']}m\033[1;3mok\u001b[0m" - assert result == expected - - def test_get_colored_text_yellow_exact_format(self): - result = self.cb._get_colored_text("warn", "yellow") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['yellow']}m\033[1;3mwarn\u001b[0m" - assert result == expected - - def test_get_colored_text_pink_exact_format(self): - result = self.cb._get_colored_text("info", "pink") - expected = f"\u001b[{_TEXT_COLOR_MAPPING['pink']}m\033[1;3minfo\u001b[0m" - assert result == expected - - def test_get_colored_text_empty_string(self): - result = self.cb._get_colored_text("", "blue") - assert isinstance(result, str) - # Empty text should still have escape codes - assert _TEXT_COLOR_MAPPING["blue"] in result - - def test_get_colored_text_invalid_color_raises_key_error(self): - with pytest.raises(KeyError): - self.cb._get_colored_text("text", "purple") - - def test_get_colored_text_with_special_characters(self): - special = "hello\nworld\ttab" - result = self.cb._get_colored_text(special, "blue") - assert special in result - - def test_get_colored_text_with_long_text(self): - long_text = "a" * 10000 - result = self.cb._get_colored_text(long_text, "green") - assert long_text in result - - -# =========================================================================== -# Integration-style tests: full workflow through a ConcreteCallback -# =========================================================================== - - -class TestConcreteCallbackIntegration: - """End-to-end workflow tests using ConcreteCallback.""" - - def test_full_invocation_lifecycle(self): - """Simulate a complete LLM invocation lifecycle through all callbacks.""" - cb = ConcreteCallback() - llm_instance = MagicMock() - model = "gpt-4o" - credentials = {"api_key": "sk-xyz"} - prompt_messages = [MagicMock(spec=PromptMessage)] - model_parameters = {"temperature": 0.5} - tools = [MagicMock(spec=PromptMessageTool)] - stop = [""] - user = "user-abc" - - # 1. Before invoke - cb.on_before_invoke( - llm_instance=llm_instance, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=True, - user=user, - ) - - # 2. Multiple chunks during streaming - for i in range(3): - chunk = MagicMock(spec=LLMResultChunk) - cb.on_new_chunk( - llm_instance=llm_instance, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=True, - user=user, - ) - - # 3. 
After invoke
-        result = MagicMock(spec=LLMResult)
-        cb.on_after_invoke(
-            llm_instance=llm_instance,
-            result=result,
-            model=model,
-            credentials=credentials,
-            prompt_messages=prompt_messages,
-            model_parameters=model_parameters,
-            tools=tools,
-            stop=stop,
-            stream=True,
-            user=user,
-        )
-
-        assert len(cb.before_invoke_calls) == 1
-        assert len(cb.new_chunk_calls) == 3
-        assert len(cb.after_invoke_calls) == 1
-        assert len(cb.invoke_error_calls) == 0
-
-    def test_error_lifecycle(self):
-        """Simulate an invoke that results in an error."""
-        cb = ConcreteCallback()
-        llm_instance = MagicMock()
-        model = "gpt-4"
-        credentials = {}
-        prompt_messages = []
-        model_parameters = {}
-
-        cb.on_before_invoke(
-            llm_instance=llm_instance,
-            model=model,
-            credentials=credentials,
-            prompt_messages=prompt_messages,
-            model_parameters=model_parameters,
-        )
-
-        ex = RuntimeError("API timeout")
-        cb.on_invoke_error(
-            llm_instance=llm_instance,
-            ex=ex,
-            model=model,
-            credentials=credentials,
-            prompt_messages=prompt_messages,
-            model_parameters=model_parameters,
-        )
-
-        assert len(cb.before_invoke_calls) == 1
-        assert len(cb.invoke_error_calls) == 1
-        assert cb.invoke_error_calls[0]["ex"] is ex
-        assert len(cb.after_invoke_calls) == 0
-
-    def test_print_text_with_color_in_integration(self, capsys):
-        """Verify print_text works correctly in a concrete instance."""
-        cb = ConcreteCallback()
-        cb.print_text("SUCCESS", color="green", end="\n")
-        captured = capsys.readouterr()
-        assert "SUCCESS" in captured.out
-        assert "\n" in captured.out
-
-    def test_print_text_no_color_in_integration(self, capsys):
-        cb = ConcreteCallback()
-        cb.print_text("plain output")
-        captured = capsys.readouterr()
-        assert captured.out == "plain output"
diff --git a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py b/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py
deleted file mode 100644
index df9215826c..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/callbacks/test_logging_callback.py
+++ /dev/null
@@ -1,700 +0,0 @@
-"""
-Comprehensive unit tests for graphon/model_runtime/callbacks/logging_callback.py
-
-Coverage targets:
-    - LoggingCallback.on_before_invoke (all branches: stop, tools, user, stream,
-      prompt_message.name, model_parameters)
-    - LoggingCallback.on_new_chunk (writes to stdout)
-    - LoggingCallback.on_after_invoke (all branches: tool_calls present / absent)
-    - LoggingCallback.on_invoke_error (logs exception via logger.exception)
-"""
-
-from __future__ import annotations
-
-import json
-from collections.abc import Sequence
-from decimal import Decimal
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from graphon.model_runtime.callbacks.logging_callback import LoggingCallback
-from graphon.model_runtime.entities.llm_entities import (
-    LLMResult,
-    LLMResultChunk,
-    LLMResultChunkDelta,
-    LLMUsage,
-)
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessageTool,
-    SystemPromptMessage,
-    UserPromptMessage,
-)
-
-# ---------------------------------------------------------------------------
-# Shared helpers
-# ---------------------------------------------------------------------------
-
-
-def _make_usage() -> LLMUsage:
-    """Return a minimal LLMUsage instance."""
-    return LLMUsage(
-        prompt_tokens=10,
-        prompt_unit_price=Decimal("0.001"),
-        prompt_price_unit=Decimal("0.001"),
-        prompt_price=Decimal("0.01"),
-        completion_tokens=20,
-        completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.002"), - completion_price=Decimal("0.04"), - total_tokens=30, - total_price=Decimal("0.05"), - currency="USD", - latency=0.5, - ) - - -def _make_llm_result( - content: str = "hello world", - tool_calls: list | None = None, - model: str = "gpt-4", - system_fingerprint: str | None = "fp-abc", -) -> LLMResult: - """Return an LLMResult with an AssistantPromptMessage.""" - assistant_msg = AssistantPromptMessage( - content=content, - tool_calls=tool_calls or [], - ) - return LLMResult( - model=model, - message=assistant_msg, - usage=_make_usage(), - system_fingerprint=system_fingerprint, - ) - - -def _make_chunk(content: str = "chunk-text") -> LLMResultChunk: - """Return a minimal LLMResultChunk.""" - return LLMResultChunk( - model="gpt-4", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content), - ), - ) - - -def _make_user_prompt(content: str = "Hello!", name: str | None = None) -> UserPromptMessage: - return UserPromptMessage(content=content, name=name) - - -def _make_system_prompt(content: str = "You are helpful.") -> SystemPromptMessage: - return SystemPromptMessage(content=content) - - -def _make_tool(name: str = "my_tool") -> PromptMessageTool: - return PromptMessageTool(name=name, description="A tool", parameters={}) - - -def _make_tool_call( - call_id: str = "call-1", - func_name: str = "some_func", - arguments: str = '{"key": "value"}', -) -> AssistantPromptMessage.ToolCall: - return AssistantPromptMessage.ToolCall( - id=call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=func_name, arguments=arguments), - ) - - -# --------------------------------------------------------------------------- -# Fixture: shared LoggingCallback instance (no heavy state) -# --------------------------------------------------------------------------- - - -@pytest.fixture -def cb() -> LoggingCallback: - return LoggingCallback() - - -@pytest.fixture -def llm_instance() -> MagicMock: - return MagicMock() - - -# =========================================================================== -# Tests for on_before_invoke -# =========================================================================== - - -class TestOnBeforeInvoke: - """Tests for LoggingCallback.on_before_invoke.""" - - def _invoke( - self, - cb: LoggingCallback, - llm_instance: MagicMock, - *, - model: str = "gpt-4", - credentials: dict | None = None, - prompt_messages: list | None = None, - model_parameters: dict | None = None, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - cb.on_before_invoke( - llm_instance=llm_instance, - model=model, - credentials=credentials or {}, - prompt_messages=prompt_messages or [], - model_parameters=model_parameters or {}, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - - def test_minimal_call_does_not_raise(self, cb: LoggingCallback, llm_instance: MagicMock): - """Calling with bare-minimum args should not raise.""" - self._invoke(cb, llm_instance) - - def test_model_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """The model name must appear in print_text calls.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, model="claude-3") - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "claude-3" in calls_text - - def test_model_parameters_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """Each key-value pair of 
model_parameters must be printed.""" - params = {"temperature": 0.7, "max_tokens": 512} - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, model_parameters=params) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "temperature" in calls_text - assert "0.7" in calls_text - assert "max_tokens" in calls_text - assert "512" in calls_text - - def test_empty_model_parameters(self, cb: LoggingCallback, llm_instance: MagicMock): - """Empty model_parameters dict should not raise.""" - self._invoke(cb, llm_instance, model_parameters={}) - - # ------------------------------------------------------------------ - # stop branch - # ------------------------------------------------------------------ - - def test_stop_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """stop words must appear in output when provided.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=["STOP", "END"]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "stop" in calls_text - - def test_stop_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stop=None the stop line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tstop:" not in calls_text - - def test_stop_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stop=[] (falsy) the stop line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stop=[]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tstop:" not in calls_text - - # ------------------------------------------------------------------ - # tools branch - # ------------------------------------------------------------------ - - def test_tools_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """Tool names must appear in output when tools are provided.""" - tools = [_make_tool("search"), _make_tool("calculate")] - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=tools) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "search" in calls_text - assert "calculate" in calls_text - - def test_tools_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When tools=None the Tools section must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tools:" not in calls_text - - def test_tools_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): - """When tools=[] (falsy) the Tools section must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, tools=[]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "Tools:" not in calls_text - - # ------------------------------------------------------------------ - # user branch - # ------------------------------------------------------------------ - - def test_user_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): - """User string must appear in output when provided.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, 
llm_instance, user="alice") - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "alice" in calls_text - - def test_user_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): - """When user=None the User line must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, user=None) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "User:" not in calls_text - - # ------------------------------------------------------------------ - # stream branch - # ------------------------------------------------------------------ - - def test_stream_true_prints_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stream=True the [on_llm_new_chunk] marker must be printed.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stream=True) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_new_chunk]" in calls_text - - def test_stream_false_no_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): - """When stream=False the [on_llm_new_chunk] marker must NOT appear.""" - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, stream=False) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "[on_llm_new_chunk]" not in calls_text - - # ------------------------------------------------------------------ - # prompt_messages branch - # ------------------------------------------------------------------ - - def test_prompt_message_with_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """When a PromptMessage has a name it must be printed.""" - msg = _make_user_prompt("hi", name="bob") - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "bob" in calls_text - - def test_prompt_message_without_name_skips_name_line(self, cb: LoggingCallback, llm_instance: MagicMock): - """When a PromptMessage has no name the name line must NOT appear.""" - msg = _make_user_prompt("hi", name=None) - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "\tname:" not in calls_text - - def test_prompt_message_role_and_content_printed(self, cb: LoggingCallback, llm_instance: MagicMock): - """Role and content of each PromptMessage must appear in output.""" - msg = _make_system_prompt("Be concise.") - with patch.object(cb, "print_text") as mock_print: - self._invoke(cb, llm_instance, prompt_messages=[msg]) - calls_text = " ".join(str(c) for c in mock_print.call_args_list) - assert "system" in calls_text - assert "Be concise." 
in calls_text
-
-    def test_multiple_prompt_messages_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """All entries in prompt_messages are iterated and printed."""
-        msgs = [
-            _make_system_prompt("sys"),
-            _make_user_prompt("user msg"),
-        ]
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, prompt_messages=msgs)
-        calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-        assert "sys" in calls_text
-        assert "user msg" in calls_text
-
-    # ------------------------------------------------------------------
-    # Combination: everything provided
-    # ------------------------------------------------------------------
-
-    def test_all_optional_fields_combined(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """Supply stop, tools, user, multiple params, named message – no exception."""
-        msgs = [_make_user_prompt("question", name="alice")]
-        tools = [_make_tool("tool_a")]
-        with patch.object(cb, "print_text"):
-            self._invoke(
-                cb,
-                llm_instance,
-                model="gpt-3.5",
-                model_parameters={"temperature": 1.0, "top_p": 0.9},
-                tools=tools,
-                stop=["DONE"],
-                stream=True,
-                user="alice",
-                prompt_messages=msgs,
-            )
-
-
-# ===========================================================================
-# Tests for on_new_chunk
-# ===========================================================================
-
-
-class TestOnNewChunk:
-    """Tests for LoggingCallback.on_new_chunk."""
-
-    def test_chunk_content_written_to_stdout(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """on_new_chunk must write the chunk's text content to sys.stdout."""
-        chunk = _make_chunk("hello from LLM")
-        written = []
-
-        with patch("sys.stdout") as mock_stdout:
-            mock_stdout.write.side_effect = written.append
-            cb.on_new_chunk(
-                llm_instance=llm_instance,
-                chunk=chunk,
-                model="gpt-4",
-                credentials={},
-                prompt_messages=[],
-                model_parameters={},
-            )
-        mock_stdout.write.assert_called_once_with("hello from LLM")
-        mock_stdout.flush.assert_called_once()
-
-    def test_chunk_content_empty_string(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """Works correctly even when the chunk content is an empty string."""
-        chunk = _make_chunk("")
-        with patch("sys.stdout") as mock_stdout:
-            cb.on_new_chunk(
-                llm_instance=llm_instance,
-                chunk=chunk,
-                model="gpt-4",
-                credentials={},
-                prompt_messages=[],
-                model_parameters={},
-            )
-        mock_stdout.write.assert_called_once_with("")
-        mock_stdout.flush.assert_called_once()
-
-    def test_chunk_passes_all_optional_params(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """All optional parameters are accepted without errors."""
-        chunk = _make_chunk("data")
-        with patch("sys.stdout"):
-            cb.on_new_chunk(
-                llm_instance=llm_instance,
-                chunk=chunk,
-                model="gpt-4",
-                credentials={"key": "secret"},
-                prompt_messages=[_make_user_prompt("q")],
-                model_parameters={"temperature": 0.5},
-                tools=[_make_tool("t1")],
-                stop=["EOS"],
-                stream=True,
-                user="bob",
-            )
-
-
-# ===========================================================================
-# Tests for on_after_invoke
-# ===========================================================================
-
-
-class TestOnAfterInvoke:
-    """Tests for LoggingCallback.on_after_invoke."""
-
-    def _invoke(
-        self,
-        cb: LoggingCallback,
-        llm_instance: MagicMock,
-        result: LLMResult,
-        **kwargs,
-    ):
-        cb.on_after_invoke(
-            llm_instance=llm_instance,
-            result=result,
-            model=result.model,
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-            **kwargs,
-        )
-
-    def test_basic_result_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """After-invoke header, content, model, usage, fingerprint must be printed."""
-        result = _make_llm_result()
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, result)
-        calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-        assert "[on_llm_after_invoke]" in calls_text
-        assert "hello world" in calls_text
-        assert "gpt-4" in calls_text
-        assert "fp-abc" in calls_text
-
-    def test_no_tool_calls_skips_tool_call_block(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """When there are no tool_calls the 'Tool calls:' block must NOT appear."""
-        result = _make_llm_result(tool_calls=[])
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, result)
-        calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-        assert "Tool calls:" not in calls_text
-
-    def test_with_tool_calls_prints_all_fields(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """When tool_calls exist their id, name, and JSON arguments must be printed."""
-        tc = _make_tool_call(
-            call_id="call-xyz",
-            func_name="fetch_data",
-            arguments='{"url": "https://example.com"}',
-        )
-        result = _make_llm_result(tool_calls=[tc])
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, result)
-        calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-        assert "Tool calls:" in calls_text
-        assert "call-xyz" in calls_text
-        assert "fetch_data" in calls_text
-        # arguments should be JSON-dumped
-        assert "https://example.com" in calls_text
-
-    def test_multiple_tool_calls_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """All tool calls in the list must be iterated."""
-        tcs = [
-            _make_tool_call("id-1", "func_a", '{"a": 1}'),
-            _make_tool_call("id-2", "func_b", '{"b": 2}'),
-        ]
-        result = _make_llm_result(tool_calls=tcs)
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, result)
-        calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-        assert "id-1" in calls_text
-        assert "func_a" in calls_text
-        assert "id-2" in calls_text
-        assert "func_b" in calls_text
-
-    def test_system_fingerprint_none_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """When system_fingerprint is None it should still be printed (as None)."""
-        result = _make_llm_result(system_fingerprint=None)
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, result)
-        calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-        assert "System Fingerprint: None" in calls_text
-
-    def test_usage_printed(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """The usage object must appear in the printed output."""
-        result = _make_llm_result()
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, result)
-        calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-        assert "Usage:" in calls_text
-
-    def test_tool_call_arguments_are_json_dumped(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """Verify json.dumps is applied to the arguments field (a string)."""
-        raw_args = '{"x": 42}'
-        tc = _make_tool_call(arguments=raw_args)
-        result = _make_llm_result(tool_calls=[tc])
-        with patch.object(cb, "print_text") as mock_print:
-            self._invoke(cb, llm_instance, result)
-
-        # Check if any call to print_text included the expected (json-encoded) arguments
-        # json.dumps(raw_args) produces a string starting and ending with quotes
-        expected_substring = json.dumps(raw_args)
-        found = any(expected_substring in str(call.args[0]) for call in mock_print.call_args_list)
-        assert found, f"Expected {expected_substring} to be printed in one of the calls"
-
-    def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """All optional parameters should be accepted without error."""
-        result = _make_llm_result()
-        cb.on_after_invoke(
-            llm_instance=llm_instance,
-            result=result,
-            model=result.model,
-            credentials={"key": "secret"},
-            prompt_messages=[_make_user_prompt("q")],
-            model_parameters={"temperature": 0.9},
-            tools=[_make_tool("t")],
-            stop=[""],
-            stream=False,
-            user="carol",
-        )
-
-
-# ===========================================================================
-# Tests for on_invoke_error
-# ===========================================================================
-
-
-class TestOnInvokeError:
-    """Tests for LoggingCallback.on_invoke_error."""
-
-    def _invoke_error(
-        self,
-        cb: LoggingCallback,
-        llm_instance: MagicMock,
-        ex: Exception,
-        **kwargs,
-    ):
-        cb.on_invoke_error(
-            llm_instance=llm_instance,
-            ex=ex,
-            model="gpt-4",
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-            **kwargs,
-        )
-
-    def test_prints_error_header(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """The [on_llm_invoke_error] banner must be printed."""
-        with patch.object(cb, "print_text") as mock_print:
-            with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger:
-                self._invoke_error(cb, llm_instance, RuntimeError("boom"))
-        calls_text = " ".join(str(c) for c in mock_print.call_args_list)
-        assert "[on_llm_invoke_error]" in calls_text
-
-    def test_exception_logged_via_logger_exception(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """logger.exception must be called with the exception."""
-        ex = ValueError("something went wrong")
-        with patch.object(cb, "print_text"):
-            with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger:
-                self._invoke_error(cb, llm_instance, ex)
-        mock_logger.exception.assert_called_once_with(ex)
-
-    def test_exception_type_variety(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """Works with any exception type (TypeError, IOError, etc.)."""
-        for exc_cls in (TypeError, IOError, KeyError, Exception):
-            ex = exc_cls("error")
-            with patch.object(cb, "print_text"):
-                with patch("graphon.model_runtime.callbacks.logging_callback.logger") as mock_logger:
-                    self._invoke_error(cb, llm_instance, ex)
-            mock_logger.exception.assert_called_once_with(ex)
-
-    def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock):
-        """All optional parameters should be accepted without error."""
-        ex = RuntimeError("fail")
-        with patch.object(cb, "print_text"):
-            with patch("graphon.model_runtime.callbacks.logging_callback.logger"):
-                cb.on_invoke_error(
-                    llm_instance=llm_instance,
-                    ex=ex,
-                    model="gpt-4",
-                    credentials={"key": "secret"},
-                    prompt_messages=[_make_user_prompt("q")],
-                    model_parameters={"temperature": 0.7},
-                    tools=[_make_tool("t")],
-                    stop=["STOP"],
-                    stream=True,
-                    user="dave",
-                )
-
-
-# ===========================================================================
-# Tests for print_text (inherited from Callback, exercised through LoggingCallback)
-# ===========================================================================
-
-
-class TestPrintText:
-    """Verify that print_text from the Callback base class works correctly."""
-
-    def test_print_text_with_color(self, cb: LoggingCallback, capsys):
-        """print_text with a known colour should emit an ANSI escape sequence."""
-        cb.print_text("hello", color="blue")
-        captured = capsys.readouterr()
-        assert "hello" in captured.out
-        # ANSI escape codes should be present
-        assert "\x1b[" in captured.out
-
-    def test_print_text_without_color(self, cb: LoggingCallback, capsys):
-        """print_text without colour should print plain text."""
-        cb.print_text("plain text")
-        captured = capsys.readouterr()
-        assert "plain text" in captured.out
-
-    def test_print_text_all_colours(self, cb: LoggingCallback, capsys):
-        """Verify all supported colour keys don't raise."""
-        for colour in ("blue", "yellow", "pink", "green", "red"):
-            cb.print_text("x", color=colour)
-        captured = capsys.readouterr()
-        # All outputs should contain 'x' (5 calls)
-        assert captured.out.count("x") >= 5
-
-
-# ===========================================================================
-# Integration-style test: real print_text called (no mocking)
-# ===========================================================================
-
-
-class TestLoggingCallbackIntegration:
-    """Light integration tests – real print_text calls, just checking no exceptions."""
-
-    def test_on_before_invoke_full_run(self, capsys):
-        """Full on_before_invoke run with all optional fields – verifies real output."""
-        cb = LoggingCallback()
-        llm = MagicMock()
-        msgs = [_make_user_prompt("Who are you?", name="tester")]
-        tools = [_make_tool("calculator")]
-        cb.on_before_invoke(
-            llm_instance=llm,
-            model="gpt-4-turbo",
-            credentials={"api_key": "sk-xxx"},
-            prompt_messages=msgs,
-            model_parameters={"temperature": 0.8},
-            tools=tools,
-            stop=["STOP"],
-            stream=True,
-            user="test_user",
-        )
-        captured = capsys.readouterr()
-        assert "gpt-4-turbo" in captured.out
-        assert "calculator" in captured.out
-        assert "test_user" in captured.out
-        assert "STOP" in captured.out
-        assert "tester" in captured.out
-
-    def test_on_new_chunk_full_run(self, capsys):
-        """Full on_new_chunk run – verifies real stdout write."""
-        cb = LoggingCallback()
-        chunk = _make_chunk("streaming token")
-        cb.on_new_chunk(
-            llm_instance=MagicMock(),
-            chunk=chunk,
-            model="gpt-4",
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-        )
-        captured = capsys.readouterr()
-        assert "streaming token" in captured.out
-
-    def test_on_after_invoke_full_run_with_tool_calls(self, capsys):
-        """Full on_after_invoke run with tool calls – verifies real output."""
-        cb = LoggingCallback()
-        tc = _make_tool_call("call-99", "do_thing", '{"n": 5}')
-        result = _make_llm_result(content="result content", tool_calls=[tc], system_fingerprint="fp-xyz")
-        cb.on_after_invoke(
-            llm_instance=MagicMock(),
-            result=result,
-            model=result.model,
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-        )
-        captured = capsys.readouterr()
-        assert "result content" in captured.out
-        assert "call-99" in captured.out
-        assert "do_thing" in captured.out
-        assert "fp-xyz" in captured.out
-
-    def test_on_invoke_error_full_run(self, capsys):
-        """Full on_invoke_error run – just verifies no exception is raised."""
-        cb = LoggingCallback()
-        ex = RuntimeError("something bad happened")
-        # logger.exception writes to stderr; we just confirm it doesn't crash
-        cb.on_invoke_error(
-            llm_instance=MagicMock(),
-            ex=ex,
-            model="gpt-4",
-            credentials={},
-            prompt_messages=[],
-            model_parameters={},
-        )
-        captured = capsys.readouterr()
-        assert "[on_llm_invoke_error]" in captured.out
diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py
deleted file mode 100644
index 7d6255c37a..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/entities/test_common_entities.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from graphon.model_runtime.entities.common_entities import I18nObject
-
-
-class TestI18nObject:
-    def test_i18n_object_with_both_languages(self):
-        """
-        Test I18nObject when both zh_Hans and en_US are provided.
-        """
-        i18n = I18nObject(zh_Hans="你好", en_US="Hello")
-        assert i18n.zh_Hans == "你好"
-        assert i18n.en_US == "Hello"
-
-    def test_i18n_object_fallback_to_en_us(self):
-        """
-        Test I18nObject when zh_Hans is missing, it should fallback to en_US.
-        """
-        i18n = I18nObject(en_US="Hello")
-        assert i18n.zh_Hans == "Hello"
-        assert i18n.en_US == "Hello"
-
-    def test_i18n_object_with_none_zh_hans(self):
-        """
-        Test I18nObject when zh_Hans is None, it should fallback to en_US.
-        """
-        i18n = I18nObject(zh_Hans=None, en_US="Hello")
-        assert i18n.zh_Hans == "Hello"
-        assert i18n.en_US == "Hello"
-
-    def test_i18n_object_with_empty_zh_hans(self):
-        """
-        Test I18nObject when zh_Hans is an empty string, it should fallback to en_US.
-        """
-        i18n = I18nObject(zh_Hans="", en_US="Hello")
-        assert i18n.zh_Hans == "Hello"
-        assert i18n.en_US == "Hello"
diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py
deleted file mode 100644
index 51a6c38fa9..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/entities/test_llm_entities.py
+++ /dev/null
@@ -1,148 +0,0 @@
-"""Tests for LLMUsage entity."""
-
-from decimal import Decimal
-
-from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
-
-
-class TestLLMUsage:
-    """Test cases for LLMUsage class."""
-
-    def test_from_metadata_with_all_tokens(self):
-        """Test from_metadata when all token types are provided."""
-        metadata: LLMUsageMetadata = {
-            "prompt_tokens": 100,
-            "completion_tokens": 50,
-            "total_tokens": 150,
-            "prompt_unit_price": 0.001,
-            "completion_unit_price": 0.002,
-            "total_price": 0.2,
-            "currency": "USD",
-            "latency": 1.5,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 100
-        assert usage.completion_tokens == 50
-        assert usage.total_tokens == 150
-        assert usage.prompt_unit_price == Decimal("0.001")
-        assert usage.completion_unit_price == Decimal("0.002")
-        assert usage.total_price == Decimal("0.2")
-        assert usage.currency == "USD"
-        assert usage.latency == 1.5
-
-    def test_from_metadata_with_prompt_tokens_only(self):
-        """Test from_metadata when only prompt_tokens is provided."""
-        metadata: LLMUsageMetadata = {
-            "prompt_tokens": 100,
-            "total_tokens": 100,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 100
-        assert usage.completion_tokens == 0
-        assert usage.total_tokens == 100
-
-    def test_from_metadata_with_completion_tokens_only(self):
-        """Test from_metadata when only completion_tokens is provided."""
-        metadata: LLMUsageMetadata = {
-            "completion_tokens": 50,
-            "total_tokens": 50,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 0
-        assert usage.completion_tokens == 50
-        assert usage.total_tokens == 50
-
-    def test_from_metadata_calculates_total_when_missing(self):
-        """Test from_metadata calculates total_tokens when not provided."""
-        metadata: LLMUsageMetadata = {
-            "prompt_tokens": 100,
-            "completion_tokens": 50,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 100
-        assert usage.completion_tokens == 50
-        assert usage.total_tokens == 150  # Should be calculated
-
-    def test_from_metadata_with_total_but_no_completion(self):
-        """
-        Test from_metadata when total_tokens is provided but completion_tokens is 0.
-        This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens.
-        """
-        metadata: LLMUsageMetadata = {
-            "prompt_tokens": 479,
-            "completion_tokens": 0,
-            "total_tokens": 521,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        # This is the key fix - prompt tokens should remain as prompt tokens
-        assert usage.prompt_tokens == 479
-        assert usage.completion_tokens == 0
-        assert usage.total_tokens == 521
-
-    def test_from_metadata_with_empty_metadata(self):
-        """Test from_metadata with empty metadata."""
-        metadata: LLMUsageMetadata = {}
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 0
-        assert usage.completion_tokens == 0
-        assert usage.total_tokens == 0
-        assert usage.currency == "USD"
-        assert usage.latency == 0.0
-
-    def test_from_metadata_preserves_zero_completion_tokens(self):
-        """
-        Test that zero completion_tokens are preserved when explicitly set.
-        This is important for agent nodes that only use prompt tokens.
-        """
-        metadata: LLMUsageMetadata = {
-            "prompt_tokens": 1000,
-            "completion_tokens": 0,
-            "total_tokens": 1000,
-            "prompt_unit_price": 0.15,
-            "completion_unit_price": 0.60,
-            "prompt_price": 0.00015,
-            "completion_price": 0,
-            "total_price": 0.00015,
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_tokens == 1000
-        assert usage.completion_tokens == 0
-        assert usage.total_tokens == 1000
-        assert usage.prompt_price == Decimal("0.00015")
-        assert usage.completion_price == Decimal(0)
-        assert usage.total_price == Decimal("0.00015")
-
-    def test_from_metadata_with_decimal_values(self):
-        """Test from_metadata handles decimal values correctly."""
-        metadata: LLMUsageMetadata = {
-            "prompt_tokens": 100,
-            "completion_tokens": 50,
-            "total_tokens": 150,
-            "prompt_unit_price": "0.001",
-            "completion_unit_price": "0.002",
-            "prompt_price": "0.1",
-            "completion_price": "0.1",
-            "total_price": "0.2",
-        }
-
-        usage = LLMUsage.from_metadata(metadata)
-
-        assert usage.prompt_unit_price == Decimal("0.001")
-        assert usage.completion_unit_price == Decimal("0.002")
-        assert usage.prompt_price == Decimal("0.1")
-        assert usage.completion_price == Decimal("0.1")
-        assert usage.total_price == Decimal("0.2")
diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py
deleted file mode 100644
index 1918c324cc..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/entities/test_message_entities.py
+++ /dev/null
@@ -1,210 +0,0 @@
-import pytest
-
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    AudioPromptMessageContent,
-    DocumentPromptMessageContent,
-    ImagePromptMessageContent,
-    PromptMessageContent,
-    PromptMessageContentType,
-    PromptMessageFunction,
-    PromptMessageRole,
-    PromptMessageTool,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
-    UserPromptMessage,
-    VideoPromptMessageContent,
-)
-
-
-class TestPromptMessageRole:
-    def test_value_of(self):
-        assert PromptMessageRole.value_of("system") == PromptMessageRole.SYSTEM
-        assert PromptMessageRole.value_of("user") == PromptMessageRole.USER
-        assert PromptMessageRole.value_of("assistant") == PromptMessageRole.ASSISTANT
-        assert PromptMessageRole.value_of("tool") == PromptMessageRole.TOOL
-
-        with pytest.raises(ValueError, match="invalid prompt message type value invalid"):
-            PromptMessageRole.value_of("invalid")
-
-
-class TestPromptMessageEntities:
-    def test_prompt_message_tool(self):
-        tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"})
-        assert tool.name == "test_tool"
-        assert tool.description == "test desc"
-        assert tool.parameters == {"foo": "bar"}
-
-    def test_prompt_message_function(self):
-        tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"})
-        func = PromptMessageFunction(function=tool)
-        assert func.type == "function"
-        assert func.function == tool
-
-
-class TestPromptMessageContent:
-    def test_text_content(self):
-        content = TextPromptMessageContent(data="hello")
-        assert content.type == PromptMessageContentType.TEXT
-        assert content.data == "hello"
-
-    def test_image_content(self):
-        content = ImagePromptMessageContent(
-            format="jpg", base64_data="abc", mime_type="image/jpeg", detail=ImagePromptMessageContent.DETAIL.HIGH
-        )
-        assert content.type == PromptMessageContentType.IMAGE
-        assert content.detail == ImagePromptMessageContent.DETAIL.HIGH
-        assert content.data == "data:image/jpeg;base64,abc"
-
-    def test_image_content_url(self):
-        content = ImagePromptMessageContent(format="jpg", url="https://example.com/image.jpg", mime_type="image/jpeg")
-        assert content.data == "https://example.com/image.jpg"
-
-    def test_audio_content(self):
-        content = AudioPromptMessageContent(format="mp3", base64_data="abc", mime_type="audio/mpeg")
-        assert content.type == PromptMessageContentType.AUDIO
-        assert content.data == "data:audio/mpeg;base64,abc"
-
-    def test_video_content(self):
-        content = VideoPromptMessageContent(format="mp4", base64_data="abc", mime_type="video/mp4")
-        assert content.type == PromptMessageContentType.VIDEO
-        assert content.data == "data:video/mp4;base64,abc"
-
-    def test_document_content(self):
-        content = DocumentPromptMessageContent(format="pdf", base64_data="abc", mime_type="application/pdf")
-        assert content.type == PromptMessageContentType.DOCUMENT
-        assert content.data == "data:application/pdf;base64,abc"
-
-
-class TestPromptMessages:
-    def test_user_prompt_message(self):
-        msg = UserPromptMessage(content="hello")
-        assert msg.role == PromptMessageRole.USER
-        assert msg.content == "hello"
-        assert msg.is_empty() is False
-        assert msg.get_text_content() == "hello"
-
-    def test_user_prompt_message_complex_content(self):
-        content = [TextPromptMessageContent(data="hello "), TextPromptMessageContent(data="world")]
-        msg = UserPromptMessage(content=content)
-        assert msg.get_text_content() == "hello world"
-
-        # Test validation from dict
-        msg2 = UserPromptMessage(content=[{"type": "text", "data": "hi"}])
-        assert isinstance(msg2.content[0], TextPromptMessageContent)
-        assert msg2.content[0].data == "hi"
-
-    def test_prompt_message_empty(self):
-        msg = UserPromptMessage(content=None)
-        assert msg.is_empty() is True
-        assert msg.get_text_content() == ""
-
-    def test_assistant_prompt_message(self):
-        msg = AssistantPromptMessage(content="thinking...")
-        assert msg.role == PromptMessageRole.ASSISTANT
-        assert msg.is_empty() is False
-
-        tool_call = AssistantPromptMessage.ToolCall(
-            id="call_1",
-            type="function",
-            function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"),
-        )
-        msg_with_tools = AssistantPromptMessage(content=None, tool_calls=[tool_call])
-        assert msg_with_tools.is_empty() is False
-        assert msg_with_tools.role == PromptMessageRole.ASSISTANT
-
-    def test_assistant_tool_call_id_transform(self):
-        tool_call = AssistantPromptMessage.ToolCall(
-            id=123,
-            type="function",
-            function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"),
-        )
-        assert tool_call.id == "123"
-
-    def test_system_prompt_message(self):
-        msg = SystemPromptMessage(content="you are a bot")
-        assert msg.role == PromptMessageRole.SYSTEM
-        assert msg.content == "you are a bot"
-
-    def test_tool_prompt_message(self):
-        # Case 1: Both content and tool_call_id are present
-        msg = ToolPromptMessage(content="result", tool_call_id="call_1")
-        assert msg.role == PromptMessageRole.TOOL
-        assert msg.tool_call_id == "call_1"
-        assert msg.is_empty() is False
-
-        # Case 2: Content is present, but tool_call_id is empty
-        msg_content_only = ToolPromptMessage(content="result", tool_call_id="")
-        assert msg_content_only.is_empty() is False
-
-        # Case 3: Content is None, but tool_call_id is present
-        msg_id_only = ToolPromptMessage(content=None, tool_call_id="call_1")
-        assert msg_id_only.is_empty() is False
-
-        # Case 4: Both content and tool_call_id are empty
-        msg_empty = ToolPromptMessage(content=None, tool_call_id="")
-        assert msg_empty.is_empty() is True
-
-    def test_prompt_message_validation_errors(self):
-        with pytest.raises(KeyError):
-            # Invalid content type in list
-            UserPromptMessage(content=[{"type": "invalid", "data": "foo"}])
-
-        with pytest.raises(ValueError, match="invalid prompt message"):
-            # Not a dict or PromptMessageContent
-            UserPromptMessage(content=[123])
-
-    def test_prompt_message_serialization(self):
-        # Case: content is None
-        assert UserPromptMessage(content=None).serialize_content(None) is None
-
-        # Case: content is str
-        assert UserPromptMessage(content="hello").serialize_content("hello") == "hello"
-
-        # Case: content is list of dict
-        content_list = [{"type": "text", "data": "hi"}]
-        msg = UserPromptMessage(content=content_list)
-        assert msg.serialize_content(msg.content) == [{"type": PromptMessageContentType.TEXT, "data": "hi"}]
-
-        # Case: content is Sequence but not list (e.g. tuple)
-        # To hit line 204, we can call serialize_content manually or
-        # try to pass a type that pydantic doesn't convert to list in its internal state.
-        # Actually, let's just call it manually on the instance.
-        msg = UserPromptMessage(content="test")
-        content_tuple = (TextPromptMessageContent(data="hi"),)
-        assert msg.serialize_content(content_tuple) == content_tuple
-
-    def test_prompt_message_mixed_content_validation(self):
-        # Test branch: isinstance(prompt, PromptMessageContent)
-        # but not (TextPromptMessageContent | MultiModalPromptMessageContent)
-        # Line 187: prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
-
-        # We need a PromptMessageContent that is NOT Text or MultiModal.
-        # But PromptMessageContentUnionTypes discriminator handles this usually.
-        # We can bypass high-level validation by passing the object directly in a list.
-
-        class MockContent(PromptMessageContent):
-            type: PromptMessageContentType = PromptMessageContentType.TEXT
-            data: str
-
-        mock_item = MockContent(data="test")
-        msg = UserPromptMessage(content=[mock_item])
-        # It should hit line 187 and convert to TextPromptMessageContent
-        assert isinstance(msg.content[0], TextPromptMessageContent)
-        assert msg.content[0].data == "test"
-
-    def test_prompt_message_get_text_content_branches(self):
-        # content is None
-        msg_none = UserPromptMessage(content=None)
-        assert msg_none.get_text_content() == ""
-
-        # content is list but no text content
-        image = ImagePromptMessageContent(format="jpg", base64_data="abc", mime_type="image/jpeg")
-        msg_image = UserPromptMessage(content=[image])
-        assert msg_image.get_text_content() == ""
-
-        # content is list with mixed
-        text = TextPromptMessageContent(data="hello")
-        msg_mixed = UserPromptMessage(content=[text, image])
-        assert msg_mixed.get_text_content() == "hello"
diff --git a/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py b/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py
deleted file mode 100644
index 1988709faa..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/entities/test_model_entities.py
+++ /dev/null
@@ -1,220 +0,0 @@
-from decimal import Decimal
-
-import pytest
-
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import (
-    AIModelEntity,
-    DefaultParameterName,
-    FetchFrom,
-    ModelFeature,
-    ModelPropertyKey,
-    ModelType,
-    ModelUsage,
-    ParameterRule,
-    ParameterType,
-    PriceConfig,
-    PriceInfo,
-    PriceType,
-    ProviderModel,
-)
-
-
-class TestModelType:
-    def test_value_of(self):
-        assert ModelType.value_of("text-generation") == ModelType.LLM
-        assert ModelType.value_of(ModelType.LLM) == ModelType.LLM
-        assert ModelType.value_of("embeddings") == ModelType.TEXT_EMBEDDING
-        assert ModelType.value_of(ModelType.TEXT_EMBEDDING) == ModelType.TEXT_EMBEDDING
-        assert ModelType.value_of("reranking") == ModelType.RERANK
-        assert ModelType.value_of(ModelType.RERANK) == ModelType.RERANK
-        assert ModelType.value_of("speech2text") == ModelType.SPEECH2TEXT
-        assert ModelType.value_of(ModelType.SPEECH2TEXT) == ModelType.SPEECH2TEXT
-        assert ModelType.value_of("tts") == ModelType.TTS
-        assert ModelType.value_of(ModelType.TTS) == ModelType.TTS
-        assert ModelType.value_of(ModelType.MODERATION) == ModelType.MODERATION
-
-        with pytest.raises(ValueError, match="invalid origin model type invalid"):
-            ModelType.value_of("invalid")
-
-    def test_to_origin_model_type(self):
-        assert ModelType.LLM.to_origin_model_type() == "text-generation"
-        assert ModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings"
-        assert ModelType.RERANK.to_origin_model_type() == "reranking"
-        assert ModelType.SPEECH2TEXT.to_origin_model_type() == "speech2text"
-        assert ModelType.TTS.to_origin_model_type() == "tts"
-        assert ModelType.MODERATION.to_origin_model_type() == "moderation"
-
-        # Testing the else branch in to_origin_model_type
-        # Since it's a StrEnum, it's hard to get an invalid value here unless we mock or Force it.
-        # But if we look at the implementation:
-        # if self == self.LLM: ... elif ... else: raise ValueError
-        # We can try to create a "dummy" member if possible, or just skip it if we have 100% coverage otherwise.
-        # Actually, adding a new member to an enum at runtime is possible but messy.
-        # Let's see if we can trigger it.
-
-
-class TestFetchFrom:
-    def test_values(self):
-        assert FetchFrom.PREDEFINED_MODEL == "predefined-model"
-        assert FetchFrom.CUSTOMIZABLE_MODEL == "customizable-model"
-
-
-class TestModelFeature:
-    def test_values(self):
-        assert ModelFeature.TOOL_CALL == "tool-call"
-        assert ModelFeature.MULTI_TOOL_CALL == "multi-tool-call"
-        assert ModelFeature.AGENT_THOUGHT == "agent-thought"
-        assert ModelFeature.VISION == "vision"
-        assert ModelFeature.STREAM_TOOL_CALL == "stream-tool-call"
-        assert ModelFeature.DOCUMENT == "document"
-        assert ModelFeature.VIDEO == "video"
-        assert ModelFeature.AUDIO == "audio"
-        assert ModelFeature.STRUCTURED_OUTPUT == "structured-output"
-
-
-class TestDefaultParameterName:
-    def test_value_of(self):
-        assert DefaultParameterName.value_of("temperature") == DefaultParameterName.TEMPERATURE
-        assert DefaultParameterName.value_of("top_p") == DefaultParameterName.TOP_P
-
-        with pytest.raises(ValueError, match="invalid parameter name invalid"):
-            DefaultParameterName.value_of("invalid")
-
-
-class TestParameterType:
-    def test_values(self):
-        assert ParameterType.FLOAT == "float"
-        assert ParameterType.INT == "int"
-        assert ParameterType.STRING == "string"
-        assert ParameterType.BOOLEAN == "boolean"
-        assert ParameterType.TEXT == "text"
-
-
-class TestModelPropertyKey:
-    def test_values(self):
-        assert ModelPropertyKey.MODE == "mode"
-        assert ModelPropertyKey.CONTEXT_SIZE == "context_size"
-
-
-class TestProviderModel:
-    def test_provider_model(self):
-        model = ProviderModel(
-            model="gpt-4",
-            label=I18nObject(en_US="GPT-4"),
-            model_type=ModelType.LLM,
-            fetch_from=FetchFrom.PREDEFINED_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
-        )
-        assert model.model == "gpt-4"
-        assert model.support_structure_output is False
-
-        model_with_features = ProviderModel(
-            model="gpt-4",
-            label=I18nObject(en_US="GPT-4"),
-            model_type=ModelType.LLM,
-            features=[ModelFeature.STRUCTURED_OUTPUT],
-            fetch_from=FetchFrom.PREDEFINED_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
-        )
-        assert model_with_features.support_structure_output is True
-
-
-class TestParameterRule:
-    def test_parameter_rule(self):
-        rule = ParameterRule(
-            name="temperature",
-            label=I18nObject(en_US="Temperature"),
-            type=ParameterType.FLOAT,
-            default=0.7,
-            min=0.0,
-            max=1.0,
-            precision=2,
-        )
-        assert rule.name == "temperature"
-        assert rule.default == 0.7
-
-
-class TestPriceConfig:
-    def test_price_config(self):
-        config = PriceConfig(input=Decimal("0.01"), output=Decimal("0.02"), unit=Decimal("0.001"), currency="USD")
-        assert config.input == Decimal("0.01")
-        assert config.output == Decimal("0.02")
-
-
-class TestAIModelEntity:
-    def test_ai_model_entity_no_json_schema(self):
-        entity = AIModelEntity(
-            model="gpt-4",
-            label=I18nObject(en_US="GPT-4"),
-            model_type=ModelType.LLM,
-            fetch_from=FetchFrom.PREDEFINED_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
-            parameter_rules=[
-                ParameterRule(name="temperature", label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT)
-            ],
-        )
-        assert ModelFeature.STRUCTURED_OUTPUT not in (entity.features or [])
-
-    def test_ai_model_entity_with_json_schema(self):
-        # Case: json_schema in parameter rules, features is None
-        entity = AIModelEntity(
-            model="gpt-4",
-            label=I18nObject(en_US="GPT-4"),
-            model_type=ModelType.LLM,
-            fetch_from=FetchFrom.PREDEFINED_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
-            parameter_rules=[
-                ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
-            ],
-        )
-        assert ModelFeature.STRUCTURED_OUTPUT in entity.features
-
-    def test_ai_model_entity_with_json_schema_and_features_empty(self):
-        # Case: json_schema in parameter rules, features is empty list
-        entity = AIModelEntity(
-            model="gpt-4",
-            label=I18nObject(en_US="GPT-4"),
-            model_type=ModelType.LLM,
-            features=[],
-            fetch_from=FetchFrom.PREDEFINED_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
-            parameter_rules=[
-                ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
-            ],
-        )
-        assert ModelFeature.STRUCTURED_OUTPUT in entity.features
-
-    def test_ai_model_entity_with_json_schema_and_other_features(self):
-        # Case: json_schema in parameter rules, features has other things
-        entity = AIModelEntity(
-            model="gpt-4",
-            label=I18nObject(en_US="GPT-4"),
-            model_type=ModelType.LLM,
-            features=[ModelFeature.VISION],
-            fetch_from=FetchFrom.PREDEFINED_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192},
-            parameter_rules=[
-                ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING)
-            ],
-        )
-        assert ModelFeature.STRUCTURED_OUTPUT in entity.features
-        assert ModelFeature.VISION in entity.features
-
-
-class TestModelUsage:
-    def test_model_usage(self):
-        usage = ModelUsage()
-        assert isinstance(usage, ModelUsage)
-
-
-class TestPriceType:
-    def test_values(self):
-        assert PriceType.INPUT == "input"
-        assert PriceType.OUTPUT == "output"
-
-
-class TestPriceInfo:
-    def test_price_info(self):
-        info = PriceInfo(unit_price=Decimal("0.01"), unit=Decimal(1000), total_amount=Decimal("0.05"), currency="USD")
-        assert info.total_amount == Decimal("0.05")
diff --git a/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py b/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py
deleted file mode 100644
index 2004822230..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/errors/test_invoke.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from graphon.model_runtime.errors.invoke import (
-    InvokeAuthorizationError,
-    InvokeBadRequestError,
-    InvokeConnectionError,
-    InvokeError,
-    InvokeRateLimitError,
-    InvokeServerUnavailableError,
-)
-
-
-class TestInvokeErrors:
-    def test_invoke_error_with_description(self):
-        error = InvokeError("Custom description")
-        assert error.description == "Custom description"
-        assert str(error) == "Custom description"
-        assert isinstance(error, ValueError)
-
-    def test_invoke_error_without_description(self):
-        error = InvokeError()
-        assert error.description is None
-        assert str(error) == "InvokeError"
-
-    def test_invoke_connection_error(self):
-        # Now preserves class-level description
-        error = InvokeConnectionError()
-        assert error.description == "Connection Error"
-        assert str(error) == "Connection Error"
-        assert isinstance(error, InvokeError)
-
-        # Test with explicit description
-        error_with_desc = InvokeConnectionError("Connection Error")
-        assert error_with_desc.description == "Connection Error"
-        assert str(error_with_desc) == "Connection Error"
-
-    def test_invoke_server_unavailable_error(self):
-        error = InvokeServerUnavailableError()
-        assert error.description == "Server Unavailable Error"
-        assert str(error) == "Server Unavailable Error"
-        assert isinstance(error, InvokeError)
-
-    def test_invoke_rate_limit_error(self):
-        error = InvokeRateLimitError()
-        assert error.description == "Rate Limit Error"
-        assert str(error) == "Rate Limit Error"
-        assert isinstance(error, InvokeError)
-
-    def test_invoke_authorization_error(self):
-        error = InvokeAuthorizationError()
-        assert error.description == "Incorrect model credentials provided, please check and try again. "
-        assert str(error) == "Incorrect model credentials provided, please check and try again. "
-        assert isinstance(error, InvokeError)
-
-    def test_invoke_bad_request_error(self):
-        error = InvokeBadRequestError()
-        assert error.description == "Bad Request Error"
-        assert str(error) == "Bad Request Error"
-        assert isinstance(error, InvokeError)
-
-    def test_invoke_error_inheritance(self):
-        # Test that we can override the default description in subclasses
-        error = InvokeBadRequestError("Overridden Error")
-        assert error.description == "Overridden Error"
-        assert str(error) == "Overridden Error"
diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py
deleted file mode 100644
index 64edd69789..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_ai_model.py
+++ /dev/null
@@ -1,254 +0,0 @@
-import decimal
-from unittest.mock import MagicMock, patch
-
-import pytest
-from pydantic import BaseModel
-
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import (
-    AIModelEntity,
-    DefaultParameterName,
-    FetchFrom,
-    ModelPropertyKey,
-    ModelType,
-    ParameterRule,
-    ParameterType,
-    PriceConfig,
-    PriceType,
-)
-from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
-from graphon.model_runtime.errors.invoke import (
-    InvokeAuthorizationError,
-    InvokeBadRequestError,
-    InvokeConnectionError,
-    InvokeError,
-    InvokeRateLimitError,
-    InvokeServerUnavailableError,
-)
-from graphon.model_runtime.model_providers.__base.ai_model import AIModel
-
-
-class _ConcreteAIModel(AIModel):
-    model_type = ModelType.LLM
-
-
-class TestAIModel:
-    @pytest.fixture
-    def provider_schema(self) -> ProviderEntity:
-        return ProviderEntity(
-            provider="langgenius/openai/openai",
-            provider_name="openai",
-            label=I18nObject(en_US="OpenAI"),
-            supported_model_types=[ModelType.LLM],
-            configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
-        )
-
-    @pytest.fixture
-    def model_runtime(self) -> MagicMock:
-        return MagicMock()
-
-    @pytest.fixture
-    def ai_model(self, provider_schema: ProviderEntity, model_runtime: MagicMock) -> AIModel:
-        return _ConcreteAIModel(
-            provider_schema=provider_schema,
-            model_runtime=model_runtime,
-        )
-
-    def test_init_stores_runtime_state_and_is_not_pydantic_model(
-        self, ai_model: AIModel, provider_schema: ProviderEntity, model_runtime: MagicMock
-    ) -> None:
-        assert ai_model.model_type == ModelType.LLM
-        assert ai_model.provider_schema is provider_schema
-        assert ai_model.model_runtime is model_runtime
-        assert ai_model.provider == "langgenius/openai/openai"
-        assert ai_model.provider_display_name == "OpenAI"
-        assert ai_model.started_at == 0
-        assert not isinstance(ai_model, BaseModel)
-
-    def test_direct_base_class_requires_subclass_model_type(
-        self, provider_schema: ProviderEntity, model_runtime: MagicMock
-    ) -> None:
-        with pytest.raises(TypeError, match="subclasses must define model_type"):
-            AIModel(provider_schema=provider_schema, model_runtime=model_runtime)
-
-    def test_subclass_uses_class_level_model_type(
-        self, provider_schema: ProviderEntity, model_runtime: MagicMock
-    ) -> None:
-        model = _ConcreteAIModel(provider_schema=provider_schema, model_runtime=model_runtime)
-        assert model.model_type == ModelType.LLM
-
-    def test_invoke_error_mapping(self, ai_model: AIModel) -> None:
-        mapping = ai_model._invoke_error_mapping
-        assert InvokeConnectionError in mapping
-        assert InvokeServerUnavailableError in mapping
-        assert InvokeRateLimitError in mapping
-        assert InvokeAuthorizationError in mapping
-        assert InvokeBadRequestError in mapping
-        assert ValueError in mapping
-
-    def test_transform_invoke_error(self, ai_model: AIModel) -> None:
-        err = Exception("Original error")
-
-        with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}):
-            transformed = ai_model._transform_invoke_error(err)
-            assert isinstance(transformed, InvokeAuthorizationError)
-            assert "Incorrect model credentials provided" in str(transformed.description)
-
-        class CustomNonInvokeError(Exception):
-            pass
-
-        with patch.object(AIModel, "_invoke_error_mapping", {CustomNonInvokeError: [Exception]}):
-            transformed = ai_model._transform_invoke_error(err)
-            assert transformed == err
-
-        transformed = ai_model._transform_invoke_error(Exception("Unmapped"))
-        assert isinstance(transformed, InvokeError)
-        assert transformed.description == "[OpenAI] Error: Unmapped"
-
-    def test_get_price(self, ai_model: AIModel) -> None:
-        model_name = "test_model"
-        credentials = {"key": "value"}
-
-        mock_schema = MagicMock(spec=AIModelEntity)
-        mock_schema.pricing = PriceConfig(
-            input=decimal.Decimal("0.002"),
-            output=decimal.Decimal("0.004"),
-            unit=decimal.Decimal(1000),
-            currency="USD",
-        )
-
-        with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
-            price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000)
-            assert price_info.unit_price == decimal.Decimal("0.002")
-
-            price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000)
-            assert price_info.unit_price == decimal.Decimal("0.004")
-
-        mock_schema.pricing = None
-        with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
-            price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000)
-            assert price_info.total_amount == decimal.Decimal("0.0")
-
-    def test_get_price_no_price_config_error(self, ai_model: AIModel) -> None:
-        class ChangingPriceConfig:
-            def __init__(self) -> None:
-                self.input = decimal.Decimal("0.01")
-                self.unit = decimal.Decimal(1)
-                self.currency = "USD"
-                self.called = 0
-
-            def __bool__(self) -> bool:
-                self.called += 1
-                return self.called <= 2
-
-        mock_schema = MagicMock()
-        mock_schema.pricing = ChangingPriceConfig()
-
-        with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
-            with pytest.raises(ValueError, match="Price config not found"):
-                ai_model.get_price("test_model", {}, PriceType.INPUT, 1000)
-
-    def test_get_model_schema_delegates_to_runtime(
-        self, ai_model: AIModel, model_runtime: MagicMock, provider_schema: ProviderEntity
-    ) -> None:
-        model_name = "test_model"
-        credentials = {"api_key": "abc"}
-
-        mock_schema = AIModelEntity(
-            model="test_model",
-            label=I18nObject(en_US="Test Model"),
-            model_type=ModelType.LLM,
-            fetch_from=FetchFrom.PREDEFINED_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
-            parameter_rules=[],
-        )
-        model_runtime.get_model_schema.return_value = mock_schema
-
-        schema = ai_model.get_model_schema(model_name, credentials)
-
-        assert schema == mock_schema
-        model_runtime.get_model_schema.assert_called_once_with(
-            provider=provider_schema.provider,
-            model_type=ModelType.LLM,
-            model=model_name,
-            credentials=credentials,
-        )
-
-    def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(
-        self, ai_model: AIModel
-    ) -> None:
-        mock_schema = AIModelEntity(
-            model="test_model",
-            label=I18nObject(en_US="Test Model"),
-            model_type=ModelType.LLM,
-            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
-            parameter_rules=[
-                ParameterRule(
-                    name="invalid",
-                    use_template="invalid_template_name",
-                    label=I18nObject(en_US="Invalid"),
-                    type=ParameterType.FLOAT,
-                )
-            ],
-        )
-
-        with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
-            schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {})
-            assert schema is not None
-            assert schema.parameter_rules[0].use_template == "invalid_template_name"
-
-    def test_get_customizable_model_schema_from_credentials(self, ai_model: AIModel) -> None:
-        mock_schema = AIModelEntity(
-            model="test_model",
-            label=I18nObject(en_US="Test Model"),
-            model_type=ModelType.LLM,
-            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
-            parameter_rules=[
-                ParameterRule(
-                    name="temp", use_template="temperature", label=I18nObject(en_US="Temp"), type=ParameterType.FLOAT
-                ),
-                ParameterRule(
-                    name="top_p",
-                    use_template="top_p",
-                    label=I18nObject(en_US="Top P"),
-                    type=ParameterType.FLOAT,
-                    help=I18nObject(en_US=""),
-                ),
-                ParameterRule(
-                    name="max_tokens",
-                    use_template="max_tokens",
-                    label=I18nObject(en_US="Max Tokens"),
-                    type=ParameterType.INT,
-                    help=I18nObject(en_US="", zh_Hans=""),
-                ),
-                ParameterRule(name="custom", label=I18nObject(en_US="Custom"), type=ParameterType.STRING),
-            ],
-        )
-
-        with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
-            schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {})
-
-        assert schema is not None
-        assert schema.parameter_rules[0].max == 1.0
-        assert schema.parameter_rules[1].help is not None
-        assert schema.parameter_rules[1].help.en_US != ""
-        assert schema.parameter_rules[2].help is not None
-        assert schema.parameter_rules[2].help.zh_Hans != ""
-        assert schema.parameter_rules[3].use_template is None
-
-    def test_get_customizable_model_schema_from_credentials_none(self, ai_model: AIModel) -> None:
-        with patch.object(AIModel, "get_customizable_model_schema", return_value=None):
-            schema = ai_model.get_customizable_model_schema_from_credentials("model", {})
-            assert schema is None
-
-    def test_get_customizable_model_schema_default(self, ai_model: AIModel) -> None:
-        assert ai_model.get_customizable_model_schema("model", {}) is None
-
-    def test_get_default_parameter_rule_variable_map(self, ai_model: AIModel) -> None:
-        result = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE)
-        assert result["default"] == 0.0
-
-        with pytest.raises(Exception, match="Invalid model parameter rule name"):
-            ai_model._get_default_parameter_rule_variable_map("invalid_name")
diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py
deleted file mode 100644
index 668a7e3476..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_large_language_model.py
+++ /dev/null
@@ -1,452 +0,0 @@
-import logging
-from collections.abc import Generator, Iterator, Sequence
-from dataclasses import dataclass, field
-from decimal import Decimal
-from types import SimpleNamespace
-from typing import Any
-from unittest.mock import MagicMock
-
-import pytest
-
-import graphon.model_runtime.model_providers.__base.large_language_model as llm_module
-
-# Access large_language_model members via llm_module to avoid partial import issues in CI
-from graphon.model_runtime.callbacks.base_callback import Callback
-from graphon.model_runtime.entities.llm_entities import (
-    LLMResult,
-    LLMResultChunk,
-    LLMResultChunkDelta,
-    LLMUsage,
-)
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessage,
-    TextPromptMessageContent,
-    UserPromptMessage,
-)
-from graphon.model_runtime.entities.model_entities import ModelType, PriceInfo
-from graphon.model_runtime.model_providers.__base.large_language_model import _build_llm_result_from_chunks
-
-
-def _usage(prompt_tokens: int = 1, completion_tokens: int = 2) -> LLMUsage:
-    return LLMUsage(
-        prompt_tokens=prompt_tokens,
-        prompt_unit_price=Decimal("0.001"),
-        prompt_price_unit=Decimal(1),
-        prompt_price=Decimal(prompt_tokens) * Decimal("0.001"),
-        completion_tokens=completion_tokens,
-        completion_unit_price=Decimal("0.002"),
-        completion_price_unit=Decimal(1),
-        completion_price=Decimal(completion_tokens) * Decimal("0.002"),
-        total_tokens=prompt_tokens + completion_tokens,
-        total_price=Decimal(prompt_tokens) * Decimal("0.001") + Decimal(completion_tokens) * Decimal("0.002"),
-        currency="USD",
-        latency=0.0,
-    )
-
-
-def _tool_call_delta(
-    *,
-    tool_call_id: str,
-    tool_type: str = "function",
-    function_name: str = "",
-    function_arguments: str = "",
-) -> AssistantPromptMessage.ToolCall:
-    return AssistantPromptMessage.ToolCall(
-        id=tool_call_id,
-        type=tool_type,
-        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=function_name, arguments=function_arguments),
-    )
-
-
-def _chunk(
-    *,
-    model: str = "test-model",
-    content: str | list[Any] | None = None,
-    tool_calls: list[AssistantPromptMessage.ToolCall] | None = None,
-    usage: LLMUsage | None = None,
-    system_fingerprint: str | None = None,
-) -> LLMResultChunk:
-    return LLMResultChunk(
-        model=model,
-        system_fingerprint=system_fingerprint,
-        delta=LLMResultChunkDelta(
-            index=0,
-            message=AssistantPromptMessage(content=content, tool_calls=tool_calls or []),
-            usage=usage,
-        ),
-    )
-
-
-@dataclass
-class SpyCallback(Callback):
-    raise_error: bool = False
-    before: list[dict[str, Any]] = field(default_factory=list)
-    new_chunk: list[dict[str, Any]] = field(default_factory=list)
-    after: list[dict[str, Any]] = field(default_factory=list)
-    error: list[dict[str, Any]] = field(default_factory=list)
-
-    def on_before_invoke(self, **kwargs: Any) -> None:  # type: ignore[override]
-        self.before.append(kwargs)
-
-    def on_new_chunk(self, **kwargs: Any) -> None:  # type: ignore[override]
-        self.new_chunk.append(kwargs)
-
-    def on_after_invoke(self, **kwargs: Any) -> None:  # type: ignore[override]
-        self.after.append(kwargs)
-
-    def on_invoke_error(self, **kwargs: Any) -> None:  # type: ignore[override]
-        self.error.append(kwargs)
-
-
-class _TestLLM(llm_module.LargeLanguageModel):
-    def get_price(self, model: str, credentials: dict, price_type: Any, tokens: int) -> PriceInfo:  # type: ignore[override]
-        return PriceInfo(
-            unit_price=Decimal("0.01"),
-            unit=Decimal(1),
-            total_amount=Decimal(tokens) * Decimal("0.01"),
-            currency="USD",
-        )
-
-    def _transform_invoke_error(self, error: Exception) -> Exception:  # type: ignore[override]
-        return RuntimeError(f"transformed: {error}")
-
-
-@pytest.fixture
-def llm() -> _TestLLM:
-    provider_schema = SimpleNamespace(provider="provider", label=SimpleNamespace(en_US="Provider"))
-    model_runtime = MagicMock()
-    model_runtime.get_llm_num_tokens.return_value = 0
-    return _TestLLM(provider_schema=provider_schema, model_runtime=model_runtime, started_at=1.0)
-
-
-def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="abc123"))
-    assert llm_module._gen_tool_call_id() == "chatcmpl-tool-abc123"
-
-
-def test_run_callbacks_no_callbacks_noop() -> None:
-    invoked: list[int] = []
-    llm_module._run_callbacks(None, event="x", invoke=lambda _: invoked.append(1))
-    llm_module._run_callbacks([], event="x", invoke=lambda _: invoked.append(1))
-    assert invoked == []
-
-
-def test_run_callbacks_swallows_error_when_raise_error_false(caplog: pytest.LogCaptureFixture) -> None:
-    class Boom:
-        raise_error = False
-
-    caplog.set_level(logging.WARNING)
-    llm_module._run_callbacks(
-        [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom"))
-    )
-    assert any("Callback" in record.message and "failed with error" in record.message for record in caplog.records)
-
-
-def test_run_callbacks_reraises_when_raise_error_true() -> None:
-    class Boom:
-        raise_error = True
-
-    with pytest.raises(ValueError, match="boom"):
-        llm_module._run_callbacks(
-            [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom"))
-        )
-
-
-def test_get_or_create_tool_call_empty_id_returns_last() -> None:
-    calls = [
-        _tool_call_delta(tool_call_id="id1", function_name="a"),
-        _tool_call_delta(tool_call_id="id2", function_name="b"),
-    ]
-    assert llm_module._get_or_create_tool_call(calls, "") is calls[-1]
-
-
-def test_get_or_create_tool_call_empty_id_without_existing_raises() -> None:
-    with pytest.raises(ValueError, match="tool_call_id is empty"):
-        llm_module._get_or_create_tool_call([], "")
-
-
-def test_get_or_create_tool_call_creates_if_missing() -> None:
-    calls: list[AssistantPromptMessage.ToolCall] = []
-    tool_call = llm_module._get_or_create_tool_call(calls, "new-id")
-    assert tool_call.id == "new-id"
-    assert tool_call.function.name == ""
-    assert tool_call.function.arguments == ""
-    assert calls == [tool_call]
-
-
-def test_get_or_create_tool_call_returns_existing_when_found() -> None:
-    existing = _tool_call_delta(tool_call_id="same-id", function_name="fn", function_arguments="{}")
-    calls = [existing]
-    assert llm_module._get_or_create_tool_call(calls, "same-id") is existing
-
-
-def test_merge_tool_call_delta_updates_fields_and_appends_arguments() -> None:
-    tool_call = _tool_call_delta(tool_call_id="id", tool_type="function", function_name="x", function_arguments="{")
-    delta = _tool_call_delta(tool_call_id="id2", tool_type="function", function_name="y", function_arguments="}")
-    llm_module._merge_tool_call_delta(tool_call, delta)
-    assert tool_call.id == "id2"
-    assert tool_call.type == "function"
-    assert tool_call.function.name == "y"
-    assert tool_call.function.arguments == "{}"
-
-
-def test_increase_tool_call_generates_id_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="fixed"))
-    delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{")
-    existing: list[AssistantPromptMessage.ToolCall] = []
-    llm_module._increase_tool_call([delta], existing)
-    assert len(existing) == 1
-    assert existing[0].id == "chatcmpl-tool-fixed"
-    assert existing[0].function.name == "fn"
-    assert existing[0].function.arguments == "{"
-
-
-def test_increase_tool_call_merges_incremental_arguments() -> None:
-    existing: list[AssistantPromptMessage.ToolCall] = []
-    llm_module._increase_tool_call(
-        [_tool_call_delta(tool_call_id="id", function_name="fn", function_arguments="{")], existing
-    )
-    llm_module._increase_tool_call(
-        [_tool_call_delta(tool_call_id="id", function_name="", function_arguments="}")], existing
-    )
-    assert len(existing) == 1
-    assert existing[0].function.name == "fn"
-    assert existing[0].function.arguments == "{}"
-
-
-@pytest.mark.parametrize(
-    ("content", "expected_type"),
-    [
-        ("hello", str),
-        ([TextPromptMessageContent(data="hello")], list),
-    ],
-)
-def test_build_llm_result_from_chunks_accumulates_and_raises_error(
-    content: str | list[TextPromptMessageContent],
-    expected_type: type,
-    monkeypatch: pytest.MonkeyPatch,
-    caplog: pytest.LogCaptureFixture,
-) -> None:
-    monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="drain"))
-    caplog.set_level(logging.DEBUG)
-
-    tool_delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{}")
-    first = _chunk(content=content, tool_calls=[tool_delta], usage=_usage(3, 4), system_fingerprint="fp1")
-
-    def iter_with_error() -> Iterator[LLMResultChunk]:
-        yield first
-        raise RuntimeError("drain boom")
-
-    with pytest.raises(RuntimeError, match="drain boom"):
-        _build_llm_result_from_chunks(
-            model="m", prompt_messages=[UserPromptMessage(content="u")], chunks=iter_with_error()
-        )
-
-    assert any("Error while consuming non-stream plugin chunk iterator" in record.message for record in caplog.records)
-
-
-def test_build_llm_result_from_chunks_empty_iterator() -> None:
-    def empty() -> Iterator[LLMResultChunk]:
-        if False:  # pragma: no cover
-            yield _chunk()
-        return
-
-    result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=empty())
-    assert result.message.content == []
-    assert result.usage.total_tokens == 0
-    assert result.system_fingerprint is None
-
-
-def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None:
-    chunks = iter([_chunk(content="first"), _chunk(content="second")])
-    result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=chunks)
-    assert result.message.content == "firstsecond"
-
-
-def test_invoke_llm_via_runtime_passes_list_converted_stop(llm: _TestLLM) -> None:
-    llm.model_runtime = MagicMock()
-    prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),)
-    result = llm_module._invoke_llm_via_runtime(
-        llm_model=llm,
-        provider="prov",
-        model="m",
-        credentials={"k": "v"},
-        model_parameters={"temp": 1},
-        prompt_messages=prompt_messages,
-        tools=None,
-        stop=("a", "b"),
-        stream=True,
-    )
-
-    llm.model_runtime.invoke_llm.assert_called_once_with(
-        provider="prov",
-        model="m",
-        credentials={"k": "v"},
-        model_parameters={"temp": 1},
-        prompt_messages=list(prompt_messages),
-        tools=None,
-        stop=("a", "b"),
-        stream=True,
-    )
-    assert result is llm.model_runtime.invoke_llm.return_value
-
-
-def test_normalize_non_stream_runtime_result_passthrough_llmresult() -> None:
-    llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
-    assert (
-        llm_module._normalize_non_stream_runtime_result(model="m", prompt_messages=[], result=llm_result) is llm_result
-    )
-
-
-def test_normalize_non_stream_runtime_result_builds_from_chunks() -> None:
-    chunks = iter([_chunk(content="hello", usage=_usage(1, 1))])
-    result = llm_module._normalize_non_stream_runtime_result(
-        model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks
-    )
-    assert isinstance(result, LLMResult)
-    assert result.message.content == "hello"
-
-
-def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
-    plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
-    monkeypatch.setattr(
-        "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
-        lambda **_: plugin_result,
-    )
-    cb = SpyCallback()
-    prompt_messages = [UserPromptMessage(content="hi")]
-    result = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=False, callbacks=[cb])
-    assert isinstance(result, LLMResult)
-    assert result.prompt_messages == prompt_messages
-    assert len(cb.before) == 1
-    assert len(cb.after) == 1
-    assert cb.after[0]["result"].prompt_messages == prompt_messages
-
-
-def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
-    plugin_chunks = iter(
-        [
-            _chunk(model="m1", content="a"),
-            _chunk(
-                model="m2", content=[TextPromptMessageContent(data="b")], usage=_usage(2, 3), system_fingerprint="fp"
-            ),
-            _chunk(model="m3", content=None),
-        ]
-    )
-    monkeypatch.setattr(
-        "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
-        lambda **_: plugin_chunks,
-    )
-
-    cb = SpyCallback()
-    prompt_messages = [UserPromptMessage(content="hi")]
-    gen = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=True, callbacks=[cb])
-
-    assert isinstance(gen, Generator)
-    chunks = list(gen)
-    assert len(chunks) == 3
-    assert all(chunk.prompt_messages == prompt_messages for chunk in chunks)
-    assert len(cb.before) == 1
-    assert len(cb.new_chunk) == 3
-    assert len(cb.after) == 1
-    final_result: LLMResult = cb.after[0]["result"]
-    assert final_result.model == "m3"
-    assert final_result.system_fingerprint == "fp"
-    assert isinstance(final_result.message.content, list)
-    assert [c.data for c in final_result.message.content] == ["a", "b"]
-    assert final_result.usage.total_tokens == 5
-
-
-def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
-    def boom(**_: Any) -> Any:
-        raise ValueError("plugin down")
-
-    monkeypatch.setattr(
-        "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime", boom
-    )
-    cb = SpyCallback()
-    with pytest.raises(RuntimeError, match="transformed: plugin down"):
-        llm.invoke(
-            model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False, callbacks=[cb]
-        )
-    assert len(cb.error) == 1
-    assert isinstance(cb.error[0]["ex"], ValueError)
-
-
-def test_invoke_raises_not_implemented_for_unsupported_result_type(
-    llm: _TestLLM, monkeypatch: pytest.MonkeyPatch
-) -> None:
-    monkeypatch.setattr(llm_module, "_invoke_llm_via_runtime", lambda **_: "not-a-result")
-    monkeypatch.setattr(llm_module, "_normalize_non_stream_runtime_result", lambda **_: "not-a-result")
-    with pytest.raises(NotImplementedError, match="unsupported invoke result type"):
-        llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
-
-
-def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
-    captured_callbacks: list[list[Callback]] = []
-
-    class FakeLoggingCallback(SpyCallback):
-        pass
-
-    monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback)
-    monkeypatch.setattr(llm_module.logger, "isEnabledFor", lambda level: level == logging.DEBUG)
-    monkeypatch.setattr(
-        "graphon.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_runtime",
-        lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()),
-    )
-
-    original_trigger = llm._trigger_before_invoke_callbacks
-
-    def spy_trigger(*args: Any, **kwargs: Any) -> None:
-        captured_callbacks.append(list(kwargs["callbacks"]))
-        original_trigger(*args, **kwargs)
-
-    monkeypatch.setattr(llm, "_trigger_before_invoke_callbacks", spy_trigger)
-    llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
-    assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0])
-
-
-def test_get_num_tokens_returns_0_when_runtime_returns_0(llm: _TestLLM) -> None:
-    llm.model_runtime.get_llm_num_tokens.return_value = 0
-    assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0
-
-
-def test_get_num_tokens_uses_runtime(llm: _TestLLM) -> None:
-    llm.model_runtime.get_llm_num_tokens.return_value = 42
-    assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42
-    llm.model_runtime.get_llm_num_tokens.assert_called_once_with(
-        provider="provider",
-        model_type=ModelType.LLM,
-        model="m",
-        credentials={},
-        prompt_messages=[UserPromptMessage(content="x")],
-        tools=None,
-    )
-
-
-def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setattr(llm_module.time, "perf_counter", lambda: 4.5)
-    llm.started_at = 1.0
-    usage = llm.calc_response_usage(model="m", credentials={}, prompt_tokens=10, completion_tokens=5)
-    assert usage.total_tokens == 15
-    assert usage.total_price == Decimal("0.15")
-    assert usage.latency == 3.5
-
-
-def test_invoke_result_generator_raises_transformed_on_iteration_error(llm: _TestLLM) -> None:
-    def broken() -> Iterator[LLMResultChunk]:
-        yield _chunk(content="ok")
-        raise ValueError("chunk stream broken")
-
-    gen = llm._invoke_result_generator(
-        model="m",
-        result=broken(),
-        credentials={},
-        prompt_messages=[UserPromptMessage(content="u")],
-        model_parameters={},
-        callbacks=[SpyCallback()],
-    )
-
-    with pytest.raises(RuntimeError, match="transformed: chunk stream broken"):
-        list(gen)
diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py
deleted file mode 100644
index a42a930806..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_moderation_model.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
-from graphon.model_runtime.errors.invoke import InvokeError
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-
-
-@pytest.fixture
-def provider_schema() -> ProviderEntity:
-    return ProviderEntity(
-        provider="test_provider",
-        label=I18nObject(en_US="test_provider"),
-        supported_model_types=[ModelType.MODERATION],
-        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
-    )
-
-
-@pytest.fixture
-def model_runtime() -> MagicMock:
-    return MagicMock()
-
-
-@pytest.fixture
-def moderation_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> ModerationModel:
-    return ModerationModel(provider_schema=provider_schema, model_runtime=model_runtime)
-
-
-def test_model_type(moderation_model: ModerationModel) -> None:
-    assert moderation_model.model_type == ModelType.MODERATION
-
-
-def test_invoke_success(moderation_model: ModerationModel, model_runtime: MagicMock) -> None:
-    with patch("time.perf_counter", return_value=1.0):
-        model_runtime.invoke_moderation.return_value = True
-
-        result = moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text")
-
-    assert result is True
-    assert moderation_model.started_at == 1.0
-    model_runtime.invoke_moderation.assert_called_once_with(
-        provider="test_provider",
-        model="test_model",
-        credentials={"api_key": "abc"},
-        text="test text",
-    )
-
-
-def test_invoke_exception(moderation_model: ModerationModel, model_runtime: MagicMock) -> None:
-    model_runtime.invoke_moderation.side_effect = Exception("Test error")
-
-    with pytest.raises(InvokeError, match="Test error"):
-        moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text")
diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py
deleted file mode 100644
index 9650ed2db7..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_rerank_model.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from unittest.mock import MagicMock
-
-import pytest
-
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
-from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
-from graphon.model_runtime.errors.invoke import InvokeError
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-
-
-@pytest.fixture
-def provider_schema() -> ProviderEntity:
-    return ProviderEntity(
-        provider="test_provider",
-        label=I18nObject(en_US="test_provider"),
-        supported_model_types=[ModelType.RERANK],
-        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
-    )
-
-
-@pytest.fixture
-def model_runtime() -> MagicMock:
-    return MagicMock()
-
-
-@pytest.fixture
-def rerank_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> RerankModel:
-    return RerankModel(provider_schema=provider_schema, model_runtime=model_runtime)
-
-
-def test_model_type_is_rerank_by_default(rerank_model: RerankModel) -> None:
-    assert rerank_model.model_type == ModelType.RERANK
-
-
-def test_invoke_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
-    expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)])
-    model_runtime.invoke_rerank.return_value = expected
-
-    result = rerank_model.invoke(
-        model="rerank",
-        credentials={"k": "v"},
-        query="q",
-        docs=["d1", "d2"],
-        score_threshold=0.2,
-        top_n=10,
-    )
-
-    assert result == expected
-    model_runtime.invoke_rerank.assert_called_once_with(
-        provider="test_provider",
-        model="rerank",
-        credentials={"k": "v"},
-        query="q",
-        docs=["d1", "d2"],
-        score_threshold=0.2,
-        top_n=10,
-    )
-
-
-def test_invoke_transforms_and_raises_on_runtime_error(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
-    model_runtime.invoke_rerank.side_effect = Exception("runtime down")
-
-    with pytest.raises(InvokeError, match="runtime down"):
-        rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
-
-
-def test_invoke_multimodal_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
-    expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)])
-    model_runtime.invoke_multimodal_rerank.return_value = expected
-
-    query = {"type": "text", "text": "q"}
-    docs = [{"type": "text", "text": "d1"}]
-    result = rerank_model.invoke_multimodal_rerank(
-        model="mm",
-        credentials={"k": "v"},
-        query=query,
-        docs=docs,
-        score_threshold=None,
-        top_n=None,
-    )
-
-    assert result == expected
-    model_runtime.invoke_multimodal_rerank.assert_called_once_with(
-        provider="test_provider",
-        model="mm",
-        credentials={"k": "v"},
-        query=query,
-        docs=docs,
-        score_threshold=None,
-        top_n=None,
-    )
-
-
-def test_invoke_multimodal_transforms_and_raises_on_runtime_error(
-    rerank_model: RerankModel, model_runtime: MagicMock
-) -> None:
-    model_runtime.invoke_multimodal_rerank.side_effect = Exception("multimodal runtime down")
-
-    query = {"content": "q", "content_type": "text"}
-    docs = [{"content": "d1", "content_type": "text"}]
-
-    with pytest.raises(InvokeError, match="multimodal runtime down"):
-        rerank_model.invoke_multimodal_rerank(
-            model="mm",
-            credentials={},
-            query=query,
-            docs=docs,
-        )
diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py
deleted file mode 100644
index 98bb1eb1b8..0000000000
--- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_runtime_user_forwarding.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from decimal import Decimal
-from io import BytesIO
-from unittest.mock import MagicMock
-
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
-from graphon.model_runtime.entities.rerank_entities import RerankResult
-from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
-
-
-def _provider_schema(model_type: ModelType) -> ProviderEntity:
-    return ProviderEntity(
-        provider="langgenius/openai/openai",
provider_name="openai", - label=I18nObject(en_US="OpenAI"), - supported_model_types=[model_type], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -def _embedding_usage() -> EmbeddingUsage: - return EmbeddingUsage( - tokens=1, - total_tokens=1, - unit_price=Decimal(0), - price_unit=Decimal(0), - total_price=Decimal(0), - currency="USD", - latency=0.0, - ) - - -def test_large_language_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_llm.return_value = LLMResult( - model="gpt-4o-mini", - prompt_messages=[], - message=AssistantPromptMessage(content="ok"), - usage=LLMUsage.empty_usage(), - ) - model = LargeLanguageModel(provider_schema=_provider_schema(ModelType.LLM), model_runtime=runtime) - - model.invoke( - model="gpt-4o-mini", - credentials={"api_key": "secret"}, - prompt_messages=[UserPromptMessage(content="hi")], - stream=False, - ) - - assert "user_id" not in runtime.invoke_llm.call_args.kwargs - - -def test_text_embedding_model_invokes_runtime_without_user_id_for_text_requests() -> None: - runtime = MagicMock() - runtime.invoke_text_embedding.return_value = EmbeddingResult( - model="text-embedding-3-small", - embeddings=[[0.1]], - usage=_embedding_usage(), - ) - model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime) - - model.invoke( - model="text-embedding-3-small", - credentials={"api_key": "secret"}, - texts=["hello"], - ) - - assert "user_id" not in runtime.invoke_text_embedding.call_args.kwargs - - -def test_text_embedding_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None: - runtime = MagicMock() - runtime.invoke_multimodal_embedding.return_value = EmbeddingResult( - model="text-embedding-3-small", - embeddings=[[0.1]], - usage=_embedding_usage(), - ) - model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime) - - model.invoke( - model="text-embedding-3-small", - credentials={"api_key": "secret"}, - multimodel_documents=[{"content": "hello", "content_type": "text"}], - ) - - assert "user_id" not in runtime.invoke_multimodal_embedding.call_args.kwargs - - -def test_rerank_model_invokes_runtime_without_user_id_for_text_requests() -> None: - runtime = MagicMock() - runtime.invoke_rerank.return_value = RerankResult(model="rerank", docs=[]) - model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime) - - model.invoke( - model="rerank", - credentials={"api_key": "secret"}, - query="q", - docs=["d1"], - ) - - assert "user_id" not in runtime.invoke_rerank.call_args.kwargs - - -def test_rerank_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None: - runtime = MagicMock() - runtime.invoke_multimodal_rerank.return_value = RerankResult(model="rerank", docs=[]) - model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime) - - model.invoke_multimodal_rerank( - model="rerank", - credentials={"api_key": "secret"}, - query={"content": "q", "content_type": "text"}, - docs=[{"content": "d1", "content_type": "text"}], - ) - - assert "user_id" not in runtime.invoke_multimodal_rerank.call_args.kwargs - - -def test_tts_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_tts.return_value = [b"chunk"] - model = TTSModel(provider_schema=_provider_schema(ModelType.TTS), model_runtime=runtime) - - list( - model.invoke( - model="tts-1", - credentials={"api_key": "secret"}, - content_text="hello", - 
voice="alloy", - ) - ) - - assert "user_id" not in runtime.invoke_tts.call_args.kwargs - - -def test_speech_to_text_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_speech_to_text.return_value = "transcript" - model = Speech2TextModel(provider_schema=_provider_schema(ModelType.SPEECH2TEXT), model_runtime=runtime) - - model.invoke( - model="whisper-1", - credentials={"api_key": "secret"}, - file=BytesIO(b"audio"), - ) - - assert "user_id" not in runtime.invoke_speech_to_text.call_args.kwargs - - -def test_moderation_model_invokes_runtime_without_user_id() -> None: - runtime = MagicMock() - runtime.invoke_moderation.return_value = True - model = ModerationModel(provider_schema=_provider_schema(ModelType.MODERATION), model_runtime=runtime) - - model.invoke( - model="omni-moderation-latest", - credentials={"api_key": "secret"}, - text="unsafe?", - ) - - assert "user_id" not in runtime.invoke_moderation.call_args.kwargs diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py deleted file mode 100644 index b03923bbc2..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_speech2text_model.py +++ /dev/null @@ -1,56 +0,0 @@ -from io import BytesIO -from unittest.mock import MagicMock - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.SPEECH2TEXT], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def speech2text_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> Speech2TextModel: - return Speech2TextModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(speech2text_model: Speech2TextModel) -> None: - assert speech2text_model.model_type == ModelType.SPEECH2TEXT - - -def test_invoke_success(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None: - file = BytesIO(b"audio data") - model_runtime.invoke_speech_to_text.return_value = "transcribed text" - - result = speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=file) - - assert result == "transcribed text" - model_runtime.invoke_speech_to_text.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - file=file, - ) - - -def test_invoke_exception(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_speech_to_text.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=BytesIO(b"audio data")) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py 
b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py deleted file mode 100644 index 64caf3a315..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_text_embedding_model.py +++ /dev/null @@ -1,146 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.TEXT_EMBEDDING], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def text_embedding_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TextEmbeddingModel: - return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(text_embedding_model: TextEmbeddingModel) -> None: - assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING - - -def test_invoke_with_texts(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_text_embedding.return_value = expected_result - - result = text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"]) - - assert result == expected_result - model_runtime.invoke_text_embedding.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello", "world"], - input_type=EmbeddingInputType.DOCUMENT, - ) - - -def test_invoke_with_multimodal_documents(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_multimodal_embedding.return_value = expected_result - - result = text_embedding_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - multimodel_documents=[{"type": "text", "text": "hello"}], - ) - - assert result == expected_result - model_runtime.invoke_multimodal_embedding.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - documents=[{"type": "text", "text": "hello"}], - input_type=EmbeddingInputType.DOCUMENT, - ) - - -def test_invoke_no_input(text_embedding_model: TextEmbeddingModel) -> None: - with pytest.raises(ValueError, match="No texts or files provided"): - text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}) - - -def test_invoke_prefers_texts_over_multimodal_documents( - text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock -) -> None: - expected_result = MagicMock(spec=EmbeddingResult) - model_runtime.invoke_text_embedding.return_value = expected_result - - result = text_embedding_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello"], - multimodel_documents=[{"type": "text", "text": 
"world"}], - ) - - assert result == expected_result - model_runtime.invoke_text_embedding.assert_called_once() - model_runtime.invoke_multimodal_embedding.assert_not_called() - - -def test_invoke_exception(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_text_embedding.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello"]) - - -def test_get_num_tokens(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: - model_runtime.get_text_embedding_num_tokens.return_value = [1, 1] - - result = text_embedding_model.get_num_tokens( - model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"] - ) - - assert result == [1, 1] - model_runtime.get_text_embedding_num_tokens.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - texts=["hello", "world"], - ) - - -def test_get_context_size(text_embedding_model: TextEmbeddingModel) -> None: - mock_schema = MagicMock() - mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048} - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 2048 - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000 - - mock_schema.model_properties = {} - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000 - - -def test_get_max_chunks(text_embedding_model: TextEmbeddingModel) -> None: - mock_schema = MagicMock() - mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 10 - - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1 - - mock_schema.model_properties = {} - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1 diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py deleted file mode 100644 index d15efb69c3..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/test_tts_model.py +++ /dev/null @@ -1,83 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from graphon.model_runtime.errors.invoke import InvokeError -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel - - -@pytest.fixture -def provider_schema() -> ProviderEntity: - return ProviderEntity( - provider="test_provider", - label=I18nObject(en_US="test_provider"), - supported_model_types=[ModelType.TTS], - 
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - ) - - -@pytest.fixture -def model_runtime() -> MagicMock: - return MagicMock() - - -@pytest.fixture -def tts_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TTSModel: - return TTSModel(provider_schema=provider_schema, model_runtime=model_runtime) - - -def test_model_type(tts_model: TTSModel) -> None: - assert tts_model.model_type == ModelType.TTS - - -def test_invoke_success(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_tts.return_value = [b"audio_chunk"] - - result = tts_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - assert list(result) == [b"audio_chunk"] - model_runtime.invoke_tts.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - -def test_invoke_exception(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.invoke_tts.side_effect = Exception("Test error") - - with pytest.raises(InvokeError, match="Test error"): - tts_model.invoke( - model="test_model", - credentials={"api_key": "abc"}, - content_text="Hello world", - voice="alloy", - ) - - -def test_get_tts_model_voices(tts_model: TTSModel, model_runtime: MagicMock) -> None: - model_runtime.get_tts_model_voices.return_value = [{"name": "Voice1"}] - - result = tts_model.get_tts_model_voices( - model="test_model", - credentials={"api_key": "abc"}, - language="en-US", - ) - - assert result == [{"name": "Voice1"}] - model_runtime.get_tts_model_voices.assert_called_once_with( - provider="test_provider", - model="test_model", - credentials={"api_key": "abc"}, - language="en-US", - ) diff --git a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py b/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py deleted file mode 100644 index d4d3eeb18c..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py +++ /dev/null @@ -1,96 +0,0 @@ -from unittest.mock import MagicMock, patch - -import graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer as gpt2_tokenizer_module -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer - - -class TestGPT2Tokenizer: - def setup_method(self): - # Reset the global tokenizer before each test to ensure we test initialization - gpt2_tokenizer_module._tokenizer = None - - def test_get_encoder_tiktoken(self): - """ - Test that get_encoder successfully uses tiktoken when available. - """ - mock_encoding = MagicMock() - # Mock tiktoken to be sure it's used - with patch("tiktoken.get_encoding", return_value=mock_encoding) as mock_get_encoding: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_encoding - mock_get_encoding.assert_called_once_with("gpt2") - - # Verify singleton behavior within the same test - encoder2 = GPT2Tokenizer.get_encoder() - assert encoder2 is encoder - assert mock_get_encoding.call_count == 1 - - def test_get_encoder_tiktoken_fallback(self): - """ - Test that get_encoder falls back to transformers when tiktoken fails. 
- """ - # patch tiktoken.get_encoding to raise an exception - with patch("tiktoken.get_encoding", side_effect=Exception("Tiktoken failure")): - # patch transformers.GPT2Tokenizer - with patch("transformers.GPT2Tokenizer.from_pretrained") as mock_from_pretrained: - mock_transformer_tokenizer = MagicMock() - mock_from_pretrained.return_value = mock_transformer_tokenizer - - with patch( - "graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer.logger" - ) as mock_logger: - encoder = GPT2Tokenizer.get_encoder() - - assert encoder == mock_transformer_tokenizer - mock_from_pretrained.assert_called_once() - mock_logger.info.assert_called_once_with("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - - def test_get_num_tokens(self): - """ - Test get_num_tokens returns the correct count. - """ - mock_encoder = MagicMock() - mock_encoder.encode.return_value = [1, 2, 3, 4, 5] - - with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): - tokens_count = GPT2Tokenizer.get_num_tokens("test text") - assert tokens_count == 5 - mock_encoder.encode.assert_called_once_with("test text") - - def test_get_num_tokens_by_gpt2_direct(self): - """ - Test _get_num_tokens_by_gpt2 directly. - """ - mock_encoder = MagicMock() - mock_encoder.encode.return_value = [1, 2] - - with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): - tokens_count = GPT2Tokenizer._get_num_tokens_by_gpt2("hello") - assert tokens_count == 2 - mock_encoder.encode.assert_called_once_with("hello") - - def test_get_encoder_already_initialized(self): - """ - Test that if _tokenizer is already set, it returns it immediately. - """ - mock_existing_tokenizer = MagicMock() - gpt2_tokenizer_module._tokenizer = mock_existing_tokenizer - - # Tiktoken should not be called if already initialized - with patch("tiktoken.get_encoding") as mock_get_encoding: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_existing_tokenizer - mock_get_encoding.assert_not_called() - - def test_get_encoder_thread_safety(self): - """ - Simple test to ensure the lock is used. 
- """ - mock_encoding = MagicMock() - with patch("tiktoken.get_encoding", return_value=mock_encoding): - # We patch the lock in the module - with patch("graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer._lock") as mock_lock: - encoder = GPT2Tokenizer.get_encoder() - assert encoder == mock_encoding - mock_lock.__enter__.assert_called_once() - mock_lock.__exit__.assert_called_once() diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py deleted file mode 100644 index 60ded4b90a..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_common_validator.py +++ /dev/null @@ -1,201 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormOption, - FormShowOnObject, - FormType, -) -from graphon.model_runtime.schema_validators.common_validator import CommonValidator - - -class TestCommonValidator: - def test_validate_credential_form_schema_required_missing(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - with pytest.raises(ValueError, match="Variable api_key is required"): - validator._validate_credential_form_schema(schema, {}) - - def test_validate_credential_form_schema_not_required_missing_with_default(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - required=False, - default="default_value", - ) - assert validator._validate_credential_form_schema(schema, {}) == "default_value" - - def test_validate_credential_form_schema_not_required_missing_no_default(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=False - ) - assert validator._validate_credential_form_schema(schema, {}) is None - - def test_validate_credential_form_schema_max_length_exceeded(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, max_length=5 - ) - with pytest.raises(ValueError, match="Variable api_key length should not be greater than 5"): - validator._validate_credential_form_schema(schema, {"api_key": "123456"}) - - def test_validate_credential_form_schema_not_string(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT) - with pytest.raises(ValueError, match="Variable api_key should be string"): - validator._validate_credential_form_schema(schema, {"api_key": 123}) - - def test_validate_credential_form_schema_select_invalid_option(self): - validator = CommonValidator() - schema = CredentialFormSchema( - variable="mode", - label=I18nObject(en_US="Mode"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="Fast"), value="fast"), - FormOption(label=I18nObject(en_US="Slow"), value="slow"), - ], - ) - with pytest.raises(ValueError, match="Variable mode is not in options"): - validator._validate_credential_form_schema(schema, {"mode": "medium"}) - - def test_validate_credential_form_schema_select_valid_option(self): - validator = CommonValidator() - schema = 
CredentialFormSchema( - variable="mode", - label=I18nObject(en_US="Mode"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="Fast"), value="fast"), - FormOption(label=I18nObject(en_US="Slow"), value="slow"), - ], - ) - assert validator._validate_credential_form_schema(schema, {"mode": "fast"}) == "fast" - - def test_validate_credential_form_schema_switch_invalid(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) - with pytest.raises(ValueError, match="Variable enabled should be true or false"): - validator._validate_credential_form_schema(schema, {"enabled": "maybe"}) - - def test_validate_credential_form_schema_switch_valid(self): - validator = CommonValidator() - schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) - assert validator._validate_credential_form_schema(schema, {"enabled": "true"}) is True - assert validator._validate_credential_form_schema(schema, {"enabled": "FALSE"}) is False - - def test_validate_and_filter_credential_form_schemas_with_show_on(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="auth_type", - label=I18nObject(en_US="Auth Type"), - type=FormType.SELECT, - options=[ - FormOption(label=I18nObject(en_US="API Key"), value="api_key"), - FormOption(label=I18nObject(en_US="OAuth"), value="oauth"), - ], - ), - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ), - CredentialFormSchema( - variable="client_id", - label=I18nObject(en_US="Client ID"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="oauth")], - ), - ] - - # Case 1: auth_type = api_key - credentials = {"auth_type": "api_key", "api_key": "my_secret"} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - assert "auth_type" in result - assert "api_key" in result - assert "client_id" not in result - assert result["api_key"] == "my_secret" - - # Case 2: auth_type = oauth - credentials = {"auth_type": "oauth", "client_id": "my_client"} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - # Note: 'auth_type' contains 'oauth'. 'result' contains keys that pass validation. - # Since 'oauth' is not an empty string, it is in result. 
- assert "auth_type" in result - assert "api_key" not in result - assert "client_id" in result - assert result["client_id"] == "my_client" - - def test_validate_and_filter_show_on_missing_variable(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ) - ] - # auth_type is missing in credentials, so api_key should be filtered out - result = validator._validate_and_filter_credential_form_schemas(schemas, {}) - assert result == {} - - def test_validate_and_filter_show_on_mismatch_value(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="auth_type", value="api_key")], - ) - ] - # auth_type is oauth, which doesn't match show_on - result = validator._validate_and_filter_credential_form_schemas(schemas, {"auth_type": "oauth"}) - assert result == {} - - def test_validate_and_filter_multiple_show_on(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema( - variable="target", - label=I18nObject(en_US="Target"), - type=FormType.TEXT_INPUT, - show_on=[FormShowOnObject(variable="v1", value="a"), FormShowOnObject(variable="v2", value="b")], - ) - ] - # Both match - assert "target" in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "v2": "b", "target": "val"} - ) - # One mismatch - assert "target" not in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "v2": "c", "target": "val"} - ) - # One missing - assert "target" not in validator._validate_and_filter_credential_form_schemas( - schemas, {"v1": "a", "target": "val"} - ) - - def test_validate_and_filter_skips_falsy_results(self): - validator = CommonValidator() - schemas = [ - CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH), - CredentialFormSchema( - variable="empty_str", label=I18nObject(en_US="Empty"), type=FormType.TEXT_INPUT, required=False - ), - ] - # Result of false switch is False. if result: is false. Not added. - # Result of empty string is "", if result: is false. Not added. 
- credentials = {"enabled": "false", "empty_str": ""} - result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) - assert "enabled" not in result - assert "empty_str" not in result diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py deleted file mode 100644 index 3932844b91..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_model_credential_schema_validator.py +++ /dev/null @@ -1,233 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FieldModelSchema, - FormOption, - FormShowOnObject, - FormType, - ModelCredentialSchema, -) -from graphon.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator - - -def test_validate_and_filter_with_none_schema(): - validator = ModelCredentialSchemaValidator(ModelType.LLM, None) - with pytest.raises(ValueError, match="Model credential schema is None"): - validator.validate_and_filter({}) - - -def test_validate_and_filter_success(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key", zh_Hans="API Key"), - type=FormType.SECRET_INPUT, - required=True, - ), - CredentialFormSchema( - variable="optional_field", - label=I18nObject(en_US="Optional", zh_Hans="可选"), - type=FormType.TEXT_INPUT, - required=False, - default="default_val", - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - credentials = {"api_key": "sk-123456"} - result = validator.validate_and_filter(credentials) - - assert result["api_key"] == "sk-123456" - assert result["optional_field"] == "default_val" - assert credentials["__model_type"] == ModelType.LLM.value - - -def test_validate_and_filter_with_show_on(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="mode", label=I18nObject(en_US="Mode", zh_Hans="模式"), type=FormType.TEXT_INPUT, required=True - ), - CredentialFormSchema( - variable="conditional_field", - label=I18nObject(en_US="Conditional", zh_Hans="条件"), - type=FormType.TEXT_INPUT, - required=True, - show_on=[FormShowOnObject(variable="mode", value="advanced")], - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - # mode is 'simple', conditional_field should be filtered out - credentials = {"mode": "simple", "conditional_field": "secret"} - result = validator.validate_and_filter(credentials) - assert "conditional_field" not in result - assert result["mode"] == "simple" - - # mode is 'advanced', conditional_field should be kept - credentials = {"mode": "advanced", "conditional_field": "secret"} - result = validator.validate_and_filter(credentials) - assert result["conditional_field"] == "secret" - assert result["mode"] == "advanced" - - # show_on variable missing in credentials - credentials = {"conditional_field": "secret"} # mode missing - with pytest.raises(ValueError, match="Variable mode is required"): # because mode is required in schema - 
validator.validate_and_filter(credentials) - - -def test_validate_and_filter_show_on_missing_trigger_var(): - # specifically test all_show_on_match = False when variable not in credentials - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="optional_trigger", - label=I18nObject(en_US="Optional Trigger", zh_Hans="可选触发"), - type=FormType.TEXT_INPUT, - required=False, - ), - CredentialFormSchema( - variable="conditional_field", - label=I18nObject(en_US="Conditional", zh_Hans="条件"), - type=FormType.TEXT_INPUT, - required=False, - show_on=[FormShowOnObject(variable="optional_trigger", value="active")], - ), - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - # optional_trigger missing, conditional_field should be skipped - result = validator.validate_and_filter({"conditional_field": "val"}) - assert "conditional_field" not in result - - -def test_common_validator_logic_required(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", - label=I18nObject(en_US="API Key", zh_Hans="API Key"), - type=FormType.SECRET_INPUT, - required=True, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({}) - - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({"api_key": ""}) - - -def test_common_validator_logic_max_length(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="key", - label=I18nObject(en_US="Key", zh_Hans="Key"), - type=FormType.TEXT_INPUT, - required=True, - max_length=5, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable key length should not be greater than 5"): - validator.validate_and_filter({"key": "123456"}) - - -def test_common_validator_logic_invalid_type(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="key", label=I18nObject(en_US="Key", zh_Hans="Key"), type=FormType.TEXT_INPUT, required=True - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - with pytest.raises(ValueError, match="Variable key should be string"): - validator.validate_and_filter({"key": 123}) - - -def test_common_validator_logic_switch(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="enabled", - label=I18nObject(en_US="Enabled", zh_Hans="启用"), - type=FormType.SWITCH, - required=True, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({"enabled": "true"}) - assert result["enabled"] is True - - result = validator.validate_and_filter({"enabled": "false"}) - assert "enabled" not in result - - with pytest.raises(ValueError, match="Variable enabled should be true or false"): - validator.validate_and_filter({"enabled": "not_a_bool"}) - - -def test_common_validator_logic_options(): - schema = ModelCredentialSchema( - 
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="choice", - label=I18nObject(en_US="Choice", zh_Hans="选择"), - type=FormType.SELECT, - required=True, - options=[ - FormOption(label=I18nObject(en_US="A", zh_Hans="A"), value="a"), - FormOption(label=I18nObject(en_US="B", zh_Hans="B"), value="b"), - ], - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({"choice": "a"}) - assert result["choice"] == "a" - - with pytest.raises(ValueError, match="Variable choice is not in options"): - validator.validate_and_filter({"choice": "c"}) - - -def test_validate_and_filter_optional_no_default(): - schema = ModelCredentialSchema( - model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), - credential_form_schemas=[ - CredentialFormSchema( - variable="optional", - label=I18nObject(en_US="Optional", zh_Hans="可选"), - type=FormType.TEXT_INPUT, - required=False, - ) - ], - ) - validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) - - result = validator.validate_and_filter({}) - assert "optional" not in result diff --git a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py b/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py deleted file mode 100644 index f7a2a5b623..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/schema_validators/test_provider_credential_schema_validator.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest - -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderCredentialSchema -from graphon.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) - - -class TestProviderCredentialSchemaValidator: - def test_validate_and_filter_success(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ), - CredentialFormSchema( - variable="endpoint", - label=I18nObject(en_US="Endpoint"), - type=FormType.TEXT_INPUT, - required=False, - default="https://api.example.com", - ), - ] - ) - validator = ProviderCredentialSchemaValidator(schema) - - # Test valid credentials - credentials = {"api_key": "my-secret-key"} - result = validator.validate_and_filter(credentials) - - assert result == {"api_key": "my-secret-key", "endpoint": "https://api.example.com"} - - def test_validate_and_filter_missing_required(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - ] - ) - validator = ProviderCredentialSchemaValidator(schema) - - # Test missing required credentials - with pytest.raises(ValueError, match="Variable api_key is required"): - validator.validate_and_filter({}) - - def test_validate_and_filter_extra_fields_filtered(self): - # Setup schema - schema = ProviderCredentialSchema( - credential_form_schemas=[ - CredentialFormSchema( - variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True - ) - ] - ) - validator = ProviderCredentialSchemaValidator(schema) - - # Test credentials with extra 
fields - credentials = {"api_key": "my-secret-key", "extra_field": "should-be-filtered"} - result = validator.validate_and_filter(credentials) - - assert "api_key" in result - assert "extra_field" not in result - assert result == {"api_key": "my-secret-key"} - - def test_init(self): - schema = ProviderCredentialSchema(credential_form_schemas=[]) - validator = ProviderCredentialSchemaValidator(schema) - assert validator.provider_credential_schema == schema diff --git a/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py b/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py deleted file mode 100644 index 8edc143fae..0000000000 --- a/api/tests/unit_tests/graphon/model_runtime/utils/test_encoders.py +++ /dev/null @@ -1,231 +0,0 @@ -import dataclasses -import datetime -from collections import deque -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path, PurePath -from re import compile -from typing import Any -from unittest.mock import MagicMock -from uuid import UUID - -import pytest -from pydantic import BaseModel, ConfigDict -from pydantic.networks import AnyUrl, NameEmail -from pydantic.types import SecretBytes, SecretStr -from pydantic_core import Url -from pydantic_extra_types.color import Color - -from graphon.model_runtime.utils.encoders import ( - _model_dump, - decimal_encoder, - generate_encoders_by_class_tuples, - isoformat, - jsonable_encoder, -) - - -class MockEnum(Enum): - A = "a" - B = "b" - - -class MockPydanticModel(BaseModel): - model_config = ConfigDict(populate_by_name=True) - name: str - age: int - - -@dataclasses.dataclass -class MockDataclass: - name: str - value: Any - - -class MockWithDict: - def __init__(self, data): - self.data = data - - def __iter__(self): - return iter(self.data.items()) - - -class MockWithVars: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - -class TestEncoders: - def test_model_dump(self): - model = MockPydanticModel(name="test", age=20) - result = _model_dump(model) - assert result == {"name": "test", "age": 20} - - def test_isoformat(self): - d = datetime.date(2023, 1, 1) - assert isoformat(d) == "2023-01-01" - t = datetime.time(12, 0, 0) - assert isoformat(t) == "12:00:00" - - def test_decimal_encoder(self): - assert decimal_encoder(Decimal("1.0")) == 1.0 - assert decimal_encoder(Decimal(1)) == 1 - assert decimal_encoder(Decimal("1.5")) == 1.5 - assert decimal_encoder(Decimal(0)) == 0 - assert decimal_encoder(Decimal(-1)) == -1 - - def test_generate_encoders_by_class_tuples(self): - type_map = {int: str, float: str, str: int} - result = generate_encoders_by_class_tuples(type_map) - assert result[str] == (int, float) - assert result[int] == (str,) - - def test_jsonable_encoder_basic_types(self): - assert jsonable_encoder("string") == "string" - assert jsonable_encoder(123) == 123 - assert jsonable_encoder(1.23) == 1.23 - assert jsonable_encoder(None) is None - - def test_jsonable_encoder_pydantic(self): - model = MockPydanticModel(name="test", age=20) - assert jsonable_encoder(model) == {"name": "test", "age": 20} - - def test_jsonable_encoder_pydantic_root(self): - # Manually create a mock that behaves like a model with __root__ - # because Pydantic v2 handles root differently, but the code checks for "__root__" - model = MagicMock(spec=BaseModel) - # _model_dump(obj, mode="json", ...) -> model.model_dump(mode="json", ...) 
- model.model_dump.return_value = {"__root__": [1, 2, 3]} - assert jsonable_encoder(model) == [1, 2, 3] - - def test_jsonable_encoder_dataclass(self): - obj = MockDataclass(name="test", value=1) - assert jsonable_encoder(obj) == {"name": "test", "value": 1} - # Test dataclass type (should not be treated as instance) - # It should fall back to vars() or dict() or at least not crash - with pytest.raises(ValueError): - jsonable_encoder(MockDataclass) - - def test_jsonable_encoder_enum(self): - assert jsonable_encoder(MockEnum.A) == "a" - - def test_jsonable_encoder_path(self): - assert jsonable_encoder(Path("/tmp/test")) == "/tmp/test" - assert jsonable_encoder(PurePath("/tmp/test")) == "/tmp/test" - - def test_jsonable_encoder_decimal(self): - # In jsonable_encoder, Decimal is formatted as string via format(obj, "f") - assert jsonable_encoder(Decimal("1.23")) == "1.23" - assert jsonable_encoder(Decimal("1.000")) == "1.000" - - def test_jsonable_encoder_dict(self): - d = {"a": 1, "b": [2, 3], "_private": "hidden"} - assert jsonable_encoder(d) == {"a": 1, "b": [2, 3], "_private": "hidden"} - assert jsonable_encoder(d, excluded_key_prefixes=("_",)) == {"a": 1, "b": [2, 3]} - - d_with_none = {"a": 1, "b": None} - assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1} - assert jsonable_encoder(d_with_none, exclude_none=False) == {"a": 1, "b": None} - - def test_jsonable_encoder_collections(self): - assert jsonable_encoder([1, 2]) == [1, 2] - assert jsonable_encoder((1, 2)) == [1, 2] - assert jsonable_encoder({1, 2}) == [1, 2] - assert jsonable_encoder(frozenset([1, 2])) == [1, 2] - assert jsonable_encoder(deque([1, 2])) == [1, 2] - - def gen(): - yield 1 - yield 2 - - assert jsonable_encoder(gen()) == [1, 2] - - def test_jsonable_encoder_custom_encoder(self): - custom = {int: lambda x: str(x + 1)} - assert jsonable_encoder(1, custom_encoder=custom) == "2" - - # Test subclass matching for custom encoder - class SubInt(int): - pass - - assert jsonable_encoder(SubInt(1), custom_encoder=custom) == "2" - - def test_jsonable_encoder_special_types(self): - # These hit ENCODERS_BY_TYPE or encoders_by_class_tuples - assert jsonable_encoder(b"bytes") == "bytes" - assert jsonable_encoder(Color("red")) == "red" - - dt = datetime.datetime(2023, 1, 1, 12, 0, 0) - assert jsonable_encoder(dt) == dt.isoformat() - - date = datetime.date(2023, 1, 1) - assert jsonable_encoder(date) == date.isoformat() - - time = datetime.time(12, 0, 0) - assert jsonable_encoder(time) == time.isoformat() - - td = datetime.timedelta(seconds=60) - assert jsonable_encoder(td) == 60.0 - - assert jsonable_encoder(IPv4Address("127.0.0.1")) == "127.0.0.1" - assert jsonable_encoder(IPv4Interface("127.0.0.1/24")) == "127.0.0.1/24" - assert jsonable_encoder(IPv4Network("127.0.0.0/24")) == "127.0.0.0/24" - assert jsonable_encoder(IPv6Address("::1")) == "::1" - assert jsonable_encoder(IPv6Interface("::1/128")) == "::1/128" - assert jsonable_encoder(IPv6Network("::/128")) == "::/128" - - assert jsonable_encoder(NameEmail(name="test", email="test@example.com")) == "test " - - assert jsonable_encoder(compile("abc")) == "abc" - - # Secret types - # Check what they actually return in this environment - res_bytes = jsonable_encoder(SecretBytes(b"secret")) - assert "**********" in res_bytes - - res_str = jsonable_encoder(SecretStr("secret")) - assert res_str == "**********" - - u = UUID("12345678-1234-5678-1234-567812345678") - assert jsonable_encoder(u) == str(u) - - url = AnyUrl("https://example.com") - assert jsonable_encoder(url) == 
"https://example.com/" - - purl = Url("https://example.com") - assert jsonable_encoder(purl) == "https://example.com/" - - def test_jsonable_encoder_fallback(self): - # dict(obj) success - obj_dict = MockWithDict({"a": 1}) - assert jsonable_encoder(obj_dict) == {"a": 1} - - # vars(obj) success - obj_vars = MockWithVars(x=10, y=20) - assert jsonable_encoder(obj_vars) == {"x": 10, "y": 20} - - # error fallback - class ReallyUnserializable: - __slots__ = ["__weakref__"] # No __dict__ - - def __iter__(self): - raise TypeError("not iterable") - - with pytest.raises(ValueError) as exc: - jsonable_encoder(ReallyUnserializable()) - assert "not iterable" in str(exc.value) - - def test_jsonable_encoder_nested(self): - data = { - "model": MockPydanticModel(name="test", age=20), - "list": [Decimal("1.1"), {MockEnum.A: Path("/tmp")}], - "set": {1, 2}, - } - expected = { - "model": {"name": "test", "age": 20}, - "list": ["1.1", {"a": "/tmp"}], - "set": [1, 2], - } - assert jsonable_encoder(data) == expected diff --git a/api/tests/unit_tests/graphon/node_events/test_base.py b/api/tests/unit_tests/graphon/node_events/test_base.py deleted file mode 100644 index 4ff1270265..0000000000 --- a/api/tests/unit_tests/graphon/node_events/test_base.py +++ /dev/null @@ -1,19 +0,0 @@ -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.node_events.base import NodeRunResult - - -def test_node_run_result_accepts_trigger_info_metadata() -> None: - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={ - WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { - "provider_id": "provider-id", - "event_name": "event-name", - } - }, - ) - - assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == { - "provider_id": "provider-id", - "event_name": "event-name", - } diff --git a/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py b/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py deleted file mode 100644 index a8c86d288c..0000000000 --- a/api/tests/unit_tests/graphon/utils/test_json_in_md_parser.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest - -from graphon.utils.json_in_md_parser import ( - OutputParserError, - parse_and_check_json_markdown, - parse_json_markdown, -) - - -def test_parse_json_markdown_extracts_fenced_json_object() -> None: - src = """ - ```json - {"a": 1, "b": "x"} - ``` - """ - - assert parse_json_markdown(src) == {"a": 1, "b": "x"} - - -def test_parse_json_markdown_extracts_raw_json_array() -> None: - assert parse_json_markdown('[{"a": 1}]') == {"a": 1} - - -def test_parse_json_markdown_raises_when_no_json_block_exists() -> None: - with pytest.raises(ValueError, match="could not find json block"): - parse_json_markdown("plain text only") - - -def test_parse_and_check_json_markdown_unwraps_single_dict_list() -> None: - parsed = parse_and_check_json_markdown( - """ - ```json - [{"present": 1, "other": 2}] - ``` - """, - ["present"], - ) - - assert parsed == {"present": 1, "other": 2} - - -def test_parse_and_check_json_markdown_rejects_invalid_json() -> None: - with pytest.raises(OutputParserError, match="got invalid json object"): - parse_and_check_json_markdown( - """ - ```json - {invalid json} - ``` - """, - [], - ) - - -def test_parse_and_check_json_markdown_rejects_invalid_return_shapes() -> None: - with pytest.raises(OutputParserError, match="got invalid return object"): - parse_and_check_json_markdown( - """ - ```json - [1, 2] - ``` - """, - ["present"], - ) - - -def 
test_parse_and_check_json_markdown_requires_expected_keys() -> None: - with pytest.raises(OutputParserError, match="expected key `missing`"): - parse_and_check_json_markdown( - """ - ```json - {"present": 1} - ``` - """, - ["present", "missing"], - ) diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py index bf2d966745..13577b7ca5 100644 --- a/api/tests/unit_tests/libs/_human_input/support.py +++ b/api/tests/unit_tests/libs/_human_input/support.py @@ -7,6 +7,8 @@ from typing import Any from graphon.nodes.human_input.entities import FormInput from graphon.nodes.human_input.enums import TimeoutUnit +from libs.datetime_utils import naive_utc_now + # Exceptions class HumanInputError(Exception): @@ -49,7 +51,7 @@ class HumanInputForm: timeout: int timeout_unit: TimeoutUnit form_token: str | None = None - created_at: datetime = field(default_factory=datetime.utcnow) + created_at: datetime = field(default_factory=naive_utc_now) expires_at: datetime | None = None submitted_at: datetime | None = None submitted_data: dict[str, Any] | None = None @@ -61,7 +63,7 @@ class HumanInputForm: @property def is_expired(self) -> bool: - return self.expires_at is not None and datetime.utcnow() > self.expires_at + return self.expires_at is not None and naive_utc_now() > self.expires_at @property def is_submitted(self) -> bool: @@ -70,7 +72,7 @@ class HumanInputForm: def mark_submitted(self, inputs: dict[str, Any], action: str) -> None: self.submitted_data = inputs self.submitted_action = action - self.submitted_at = datetime.utcnow() + self.submitted_at = naive_utc_now() def submit(self, inputs: dict[str, Any], action: str) -> None: self.mark_submitted(inputs, action) @@ -107,7 +109,7 @@ class FormSubmissionData: form_id: str inputs: dict[str, Any] action: str - submitted_at: datetime = field(default_factory=datetime.utcnow) + submitted_at: datetime = field(default_factory=naive_utc_now) @classmethod def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py index 885791f8c9..f1ce1a2c1c 100644 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -2,10 +2,9 @@ Unit tests for FormService. 
""" -from datetime import datetime, timedelta +from datetime import timedelta import pytest - from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -14,6 +13,7 @@ from graphon.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) + from libs.datetime_utils import naive_utc_now from .support import ( @@ -142,7 +142,7 @@ class TestFormService: # Manually expire the form by modifying expiry time form = form_service.get_form_by_id("form-123") - form.expires_at = datetime.utcnow() - timedelta(hours=1) + form.expires_at = naive_utc_now() - timedelta(hours=1) form_service.repository.save(form) # Should raise FormExpiredError @@ -227,7 +227,7 @@ class TestFormService: # Manually expire the form form = form_service.get_form_by_id("form-123") - form.expires_at = datetime.utcnow() - timedelta(hours=1) + form.expires_at = naive_utc_now() - timedelta(hours=1) form_service.repository.save(form) # Try to submit expired form diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py index 8a8b88ff13..0babfbb315 100644 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -5,7 +5,6 @@ Unit tests for human input form models. from datetime import datetime, timedelta import pytest - from graphon.nodes.human_input.entities import ( FormInput, UserAction, @@ -15,6 +14,8 @@ from graphon.nodes.human_input.enums import ( TimeoutUnit, ) +from libs.datetime_utils import naive_utc_now + from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm @@ -83,7 +84,7 @@ class TestHumanInputForm: def test_form_expiry_property_expired(self, sample_form_data): """Test is_expired property for expired form.""" # Create form with past expiry - past_time = datetime.utcnow() - timedelta(hours=1) + past_time = naive_utc_now() - timedelta(hours=1) sample_form_data["created_at"] = past_time form = HumanInputForm(**sample_form_data) @@ -111,9 +112,9 @@ class TestHumanInputForm: """Test form submit method.""" form = HumanInputForm(**sample_form_data) - submission_time_before = datetime.utcnow() + submission_time_before = naive_utc_now() form.submit({"input": "test value"}, "submit") - submission_time_after = datetime.utcnow() + submission_time_after = naive_utc_now() assert form.is_submitted assert form.submitted_data == {"input": "test value"} @@ -213,11 +214,11 @@ class TestFormSubmissionData: def test_submission_data_timestamps(self): """Test submission data timestamp handling.""" - before_time = datetime.utcnow() + before_time = naive_utc_now() submission_data = FormSubmissionData(form_id="form-123", inputs={"test": "value"}, action="submit") - after_time = datetime.utcnow() + after_time = naive_utc_now() assert before_time <= submission_data.submitted_at <= after_time diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index bb3a6db1a1..86163f1554 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,7 +1,8 @@ from uuid import uuid4 -from factories import variable_factory from graphon.variables import SegmentType + +from factories import variable_factory from models import ConversationVariable diff --git a/api/tests/unit_tests/models/test_model.py b/api/tests/unit_tests/models/test_model.py index e21f0e4fbd..a5909f60a8 100644 --- a/api/tests/unit_tests/models/test_model.py +++ 
b/api/tests/unit_tests/models/test_model.py @@ -2,9 +2,9 @@ import importlib import types import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from core.workflow.file_reference import build_file_reference -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod from models.model import Conversation, Message diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 550441539a..e7c0479757 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -3,14 +3,14 @@ import json from unittest import mock from uuid import uuid4 +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from graphon.variables.segments import IntegerSegment, Segment + from constants import HIDDEN_VALUE from core.helper import encrypter from core.workflow.file_reference import build_file_reference from factories.variable_factory import build_segment -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from graphon.variables.segments import IntegerSegment, Segment from models.workflow import ( Workflow, WorkflowDraftVariable, diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index eb9fef7587..507e1c8c3a 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -13,12 +13,12 @@ from datetime import UTC, datetime from uuid import uuid4 import pytest - from graphon.enums import ( BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, ) + from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import ( Workflow, diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index ccc9c93815..10850970d8 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -9,11 +9,6 @@ from decimal import Decimal from unittest.mock import MagicMock, PropertyMock import pytest -from pytest_mock import MockerFixture -from sqlalchemy.orm import Session, sessionmaker - -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from graphon.entities import ( WorkflowNodeExecution, ) @@ -23,6 +18,11 @@ from graphon.enums import ( WorkflowNodeExecutionStatus, ) from graphon.model_runtime.utils.encoders import jsonable_encoder +from pytest_mock import MockerFixture +from sqlalchemy.orm import Session, sessionmaker + +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from models.account import Account, Tenant from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py index e8c094b75d..2322be9e80 100644 --- 
a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py @@ -6,13 +6,13 @@ from datetime import datetime from typing import Any from unittest.mock import MagicMock, Mock +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from graphon.entities.workflow_node_execution import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py deleted file mode 100644 index c6c3f677fb..0000000000 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ /dev/null @@ -1,387 +0,0 @@ -import json -from unittest.mock import Mock, patch - -import pytest - -from models.source import DataSourceApiKeyAuthBinding -from services.auth.api_key_auth_service import ApiKeyAuthService - - -class TestApiKeyAuthService: - """API key authentication service security tests""" - - def setup_method(self): - """Setup test fixtures""" - self.tenant_id = "test_tenant_123" - self.category = "search" - self.provider = "google" - self.binding_id = "binding_123" - self.mock_credentials = {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}} - self.mock_args = {"category": self.category, "provider": self.provider, "credentials": self.mock_credentials} - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_provider_auth_list_success(self, mock_session): - """Test get provider auth list - success scenario""" - # Mock database query result - mock_binding = Mock() - mock_binding.tenant_id = self.tenant_id - mock_binding.provider = self.provider - mock_binding.disabled = False - - mock_session.scalars.return_value.all.return_value = [mock_binding] - - result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) - - assert len(result) == 1 - assert result[0].tenant_id == self.tenant_id - assert mock_session.scalars.call_count == 1 - select_arg = mock_session.scalars.call_args[0][0] - assert "data_source_api_key_auth_binding" in str(select_arg).lower() - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_provider_auth_list_empty(self, mock_session): - """Test get provider auth list - empty result""" - mock_session.scalars.return_value.all.return_value = [] - - result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) - - assert result == [] - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_provider_auth_list_filters_disabled(self, mock_session): - """Test get provider auth list - filters disabled items""" - mock_session.scalars.return_value.all.return_value = [] - - ApiKeyAuthService.get_provider_auth_list(self.tenant_id) - select_stmt = mock_session.scalars.call_args[0][0] - where_clauses = list(getattr(select_stmt, "_where_criteria", []) or []) - # Ensure both tenant filter and disabled filter exist - where_strs = [str(c).lower() for c in where_clauses] - assert any("tenant_id" in s for s in where_strs) - assert any("disabled" in s for s in where_strs) - - @patch("services.auth.api_key_auth_service.db.session") - 
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_success(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - success scenario""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption - encrypted_key = "encrypted_test_key_123" - mock_encrypter.encrypt_token.return_value = encrypted_key - - # Mock database operations - mock_session.add = Mock() - mock_session.commit = Mock() - - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - # Verify factory class calls - mock_factory.assert_called_once_with(self.provider, self.mock_credentials) - mock_auth_instance.validate_credentials.assert_called_once() - - # Verify encryption calls - mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123") - - # Verify database operations - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - def test_create_provider_auth_validation_failed(self, mock_factory, mock_session): - """Test create provider auth - validation failed""" - # Mock failed auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = False - mock_factory.return_value = mock_auth_instance - - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - # Verify no database operations when validation fails - mock_session.add.assert_not_called() - mock_session.commit.assert_not_called() - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - ensures API key is encrypted""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption - encrypted_key = "encrypted_test_key_123" - mock_encrypter.encrypt_token.return_value = encrypted_key - - # Mock database operations - mock_session.add = Mock() - mock_session.commit = Mock() - - args_copy = self.mock_args.copy() - original_key = args_copy["credentials"]["config"]["api_key"] - - ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy) - - # Verify original key is replaced with encrypted key - assert args_copy["credentials"]["config"]["api_key"] == encrypted_key - assert args_copy["credentials"]["config"]["api_key"] != original_key - - # Verify encryption function is called correctly - mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key) - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_auth_credentials_success(self, mock_session): - """Test get auth credentials - success scenario""" - # Mock database query result - mock_binding = Mock() - mock_binding.credentials = json.dumps(self.mock_credentials) - mock_session.query.return_value.where.return_value.first.return_value = mock_binding - mock_session.query.return_value.where.return_value.first.return_value = mock_binding - - result = 
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - - assert result == self.mock_credentials - mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding) - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_auth_credentials_not_found(self, mock_session): - """Test get auth credentials - not found""" - mock_session.query.return_value.where.return_value.first.return_value = None - - result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - - assert result is None - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_auth_credentials_filters_correctly(self, mock_session): - """Test get auth credentials - applies correct filters""" - mock_session.query.return_value.where.return_value.first.return_value = None - - ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - - # Verify where conditions are correct - where_call = mock_session.query.return_value.where.call_args[0] - assert len(where_call) == 4 # tenant_id, category, provider, disabled - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_auth_credentials_json_parsing(self, mock_session): - """Test get auth credentials - JSON parsing""" - # Mock credentials with special characters - special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} - - mock_binding = Mock() - mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False) - mock_session.query.return_value.where.return_value.first.return_value = mock_binding - - result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - - assert result == special_credentials - assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" - - @patch("services.auth.api_key_auth_service.db.session") - def test_delete_provider_auth_success(self, mock_session): - """Test delete provider auth - success scenario""" - # Mock database query result - mock_binding = Mock() - mock_session.query.return_value.where.return_value.first.return_value = mock_binding - - ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) - - # Verify delete operations - mock_session.delete.assert_called_once_with(mock_binding) - mock_session.commit.assert_called_once() - - @patch("services.auth.api_key_auth_service.db.session") - def test_delete_provider_auth_not_found(self, mock_session): - """Test delete provider auth - not found""" - mock_session.query.return_value.where.return_value.first.return_value = None - - ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) - - # Verify no delete operations when not found - mock_session.delete.assert_not_called() - mock_session.commit.assert_not_called() - - @patch("services.auth.api_key_auth_service.db.session") - def test_delete_provider_auth_filters_by_tenant(self, mock_session): - """Test delete provider auth - filters by tenant""" - mock_session.query.return_value.where.return_value.first.return_value = None - - ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) - - # Verify where conditions include tenant_id and binding_id - where_call = mock_session.query.return_value.where.call_args[0] - assert len(where_call) == 2 - - def test_validate_api_key_auth_args_success(self): - """Test API key auth args validation - success scenario""" - # Should not raise any exception - ApiKeyAuthService.validate_api_key_auth_args(self.mock_args) - - 
def test_validate_api_key_auth_args_missing_category(self): - """Test API key auth args validation - missing category""" - args = self.mock_args.copy() - del args["category"] - - with pytest.raises(ValueError, match="category is required"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - def test_validate_api_key_auth_args_empty_category(self): - """Test API key auth args validation - empty category""" - args = self.mock_args.copy() - args["category"] = "" - - with pytest.raises(ValueError, match="category is required"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - def test_validate_api_key_auth_args_missing_provider(self): - """Test API key auth args validation - missing provider""" - args = self.mock_args.copy() - del args["provider"] - - with pytest.raises(ValueError, match="provider is required"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - def test_validate_api_key_auth_args_empty_provider(self): - """Test API key auth args validation - empty provider""" - args = self.mock_args.copy() - args["provider"] = "" - - with pytest.raises(ValueError, match="provider is required"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - def test_validate_api_key_auth_args_missing_credentials(self): - """Test API key auth args validation - missing credentials""" - args = self.mock_args.copy() - del args["credentials"] - - with pytest.raises(ValueError, match="credentials is required"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - def test_validate_api_key_auth_args_empty_credentials(self): - """Test API key auth args validation - empty credentials""" - args = self.mock_args.copy() - args["credentials"] = None - - with pytest.raises(ValueError, match="credentials is required"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - def test_validate_api_key_auth_args_invalid_credentials_type(self): - """Test API key auth args validation - invalid credentials type""" - args = self.mock_args.copy() - args["credentials"] = "not_a_dict" - - with pytest.raises(ValueError, match="credentials must be a dictionary"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - def test_validate_api_key_auth_args_missing_auth_type(self): - """Test API key auth args validation - missing auth_type""" - args = self.mock_args.copy() - del args["credentials"]["auth_type"] - - with pytest.raises(ValueError, match="auth_type is required"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - def test_validate_api_key_auth_args_empty_auth_type(self): - """Test API key auth args validation - empty auth_type""" - args = self.mock_args.copy() - args["credentials"]["auth_type"] = "" - - with pytest.raises(ValueError, match="auth_type is required"): - ApiKeyAuthService.validate_api_key_auth_args(args) - - @pytest.mark.parametrize( - "malicious_input", - [ - "", - "'; DROP TABLE users; --", - "../../../etc/passwd", - "\\x00\\x00", # null bytes - "A" * 10000, # very long input - ], - ) - def test_validate_api_key_auth_args_malicious_input(self, malicious_input): - """Test API key auth args validation - malicious input""" - args = self.mock_args.copy() - args["category"] = malicious_input - - # Verify parameter validator doesn't crash on malicious input - # Should validate normally rather than raising security-related exceptions - ApiKeyAuthService.validate_api_key_auth_args(args) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def 
test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - database error handling""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption - mock_encrypter.encrypt_token.return_value = "encrypted_key" - - # Mock database error - mock_session.commit.side_effect = Exception("Database error") - - with pytest.raises(Exception, match="Database error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_auth_credentials_invalid_json(self, mock_session): - """Test get auth credentials - invalid JSON""" - # Mock database returning invalid JSON - mock_binding = Mock() - mock_binding.credentials = "invalid json content" - mock_session.query.return_value.where.return_value.first.return_value = mock_binding - - with pytest.raises(json.JSONDecodeError): - ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - def test_create_provider_auth_factory_exception(self, mock_factory, mock_session): - """Test create provider auth - factory exception""" - # Mock factory raising exception - mock_factory.side_effect = Exception("Factory error") - - with pytest.raises(Exception, match="Factory error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - encryption exception""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption exception - mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") - - with pytest.raises(Exception, match="Encryption error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - def test_validate_api_key_auth_args_none_input(self): - """Test API key auth args validation - None input""" - with pytest.raises(TypeError): - ApiKeyAuthService.validate_api_key_auth_args(None) - - def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self): - """Test API key auth args validation - dict credentials with list auth_type""" - args = self.mock_args.copy() - args["credentials"]["auth_type"] = ["api_key"] - - # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy - # So this should not raise exception, this test should pass - ApiKeyAuthService.validate_api_key_auth_args(args) diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py deleted file mode 100644 index 3832a0b8b2..0000000000 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -API Key Authentication System Integration Tests -""" - -import json -from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, patch - -import httpx -import pytest - -from 
services.auth.api_key_auth_factory import ApiKeyAuthFactory -from services.auth.api_key_auth_service import ApiKeyAuthService -from services.auth.auth_type import AuthType - - -class TestAuthIntegration: - def setup_method(self): - self.tenant_id_1 = "tenant_123" - self.tenant_id_2 = "tenant_456" # For multi-tenant isolation testing - self.category = "search" - - # Realistic authentication configurations - self.firecrawl_credentials = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}} - self.jina_credentials = {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}} - self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}} - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") - def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session): - """Test complete authentication flow: request → validation → encryption → storage""" - mock_http.return_value = self._create_success_response() - mock_encrypt.return_value = "encrypted_fc_test_key_123" - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - - mock_http.assert_called_once() - call_args = mock_http.call_args - assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0] - assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123" - - mock_encrypt.assert_called_once_with(self.tenant_id_1, "fc_test_key_123") - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_cross_component_integration(self, mock_http): - """Test factory → provider → HTTP call integration""" - mock_http.return_value = self._create_success_response() - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - result = factory.validate_credentials() - - assert result is True - mock_http.assert_called_once() - - @patch("services.auth.api_key_auth_service.db.session") - def test_multi_tenant_isolation(self, mock_session): - """Ensure complete tenant data isolation""" - tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) - tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - - mock_session.scalars.return_value.all.return_value = [tenant1_binding] - result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - - mock_session.scalars.return_value.all.return_value = [tenant2_binding] - result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) - - assert len(result1) == 1 - assert result1[0].tenant_id == self.tenant_id_1 - assert len(result2) == 1 - assert result2[0].tenant_id == self.tenant_id_2 - - @patch("services.auth.api_key_auth_service.db.session") - def test_cross_tenant_access_prevention(self, mock_session): - """Test prevention of cross-tenant credential access""" - mock_session.query.return_value.where.return_value.first.return_value = None - - result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL) - - assert result is None - - def test_sensitive_data_protection(self): - """Ensure API keys don't leak to logs""" - credentials_with_secrets = { - "auth_type": "bearer", - 
"config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"}, - } - - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets) - factory_str = str(factory) - - assert "super_secret_key_do_not_log" not in factory_str - assert "another_secret" not in factory_str - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") - def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session): - """Test concurrent authentication creation safety""" - mock_http.return_value = self._create_success_response() - mock_encrypt.return_value = "encrypted_key" - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - - results = [] - exceptions = [] - - def create_auth(): - try: - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - results.append("success") - except Exception as e: - exceptions.append(e) - - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(create_auth) for _ in range(5)] - for future in futures: - future.result() - - assert len(results) == 5 - assert len(exceptions) == 0 - assert mock_session.add.call_count == 5 - assert mock_session.commit.call_count == 5 - - @pytest.mark.parametrize( - "invalid_input", - [ - None, # Null input - {}, # Empty dictionary - missing required fields - {"auth_type": "bearer"}, # Missing config section - {"auth_type": "bearer", "config": {}}, # Missing api_key - ], - ) - def test_invalid_input_boundary(self, invalid_input): - """Test boundary handling for invalid inputs""" - with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): - ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) - - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_http_error_handling(self, mock_http): - """Test proper HTTP error handling""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.text = '{"error": "Unauthorized"}' - mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized") - mock_http.return_value = mock_response - - # PT012: Split into single statement for pytest.raises - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - with pytest.raises((httpx.HTTPError, Exception)): - factory.validate_credentials() - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_network_failure_recovery(self, mock_http, mock_session): - """Test system recovery from network failures""" - mock_http.side_effect = httpx.RequestError("Network timeout") - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - - with pytest.raises(httpx.RequestError): - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - - mock_session.commit.assert_not_called() - - @pytest.mark.parametrize( - ("provider", "credentials"), - [ - (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}), - (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}), - (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}), - ], - ) - def test_all_providers_factory_creation(self, provider, credentials): - """Test factory creation for all 
supported providers""" - auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) - assert auth_class is not None - - factory = ApiKeyAuthFactory(provider, credentials) - assert factory.auth is not None - - def _create_success_response(self, status_code=200): - """Create successful HTTP response mock""" - mock_response = Mock() - mock_response.status_code = status_code - mock_response.json.return_value = {"status": "success"} - mock_response.raise_for_status.return_value = None - return mock_response - - def _create_mock_binding(self, tenant_id: str, provider: str, credentials: dict) -> Mock: - """Create realistic database binding mock""" - mock_binding = Mock() - mock_binding.id = f"binding_{provider}_{tenant_id}" - mock_binding.tenant_id = tenant_id - mock_binding.category = self.category - mock_binding.provider = provider - mock_binding.credentials = json.dumps(credentials, ensure_ascii=False) - mock_binding.disabled = False - - mock_binding.created_at = Mock() - mock_binding.created_at.timestamp.return_value = 1640995200 - mock_binding.updated_at = Mock() - mock_binding.updated_at.timestamp.return_value = 1640995200 - - return mock_binding - - def test_integration_coverage_validation(self): - """Validate integration test coverage meets quality standards""" - core_scenarios = { - "business_logic": ["end_to_end_auth_flow", "cross_component_integration"], - "security": ["multi_tenant_isolation", "cross_tenant_access_prevention", "sensitive_data_protection"], - "reliability": ["concurrent_creation_safety", "network_failure_recovery"], - "compatibility": ["all_providers_factory_creation"], - "boundaries": ["invalid_input_boundary", "http_error_handling"], - } - - total_scenarios = sum(len(scenarios) for scenarios in core_scenarios.values()) - assert total_scenarios >= 10 - - security_tests = core_scenarios["security"] - assert "multi_tenant_isolation" in security_tests - assert "sensitive_data_protection" in security_tests - assert True diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py index c95b60fad0..ef73bc0e01 100644 --- a/api/tests/unit_tests/services/dataset_service_test_helpers.py +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -10,6 +10,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from werkzeug.exceptions import Forbidden, NotFound from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -17,7 +18,6 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from enums.cloud_plan import CloudPlan -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from models import Account, TenantAccountRole from models.dataset import ( ChildChunk, diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 3358c8b44d..7c36e9d960 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -109,10 +109,10 @@ This test suite follows a comprehensive testing strategy that covers: from unittest.mock import Mock, patch import pytest +from 
graphon.model_runtime.entities.model_entities import ModelType from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( diff --git a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py deleted file mode 100644 index f9d901fca2..0000000000 --- a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py +++ /dev/null @@ -1,311 +0,0 @@ -import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from services.retention.conversation.messages_clean_policy import ( - BillingDisabledPolicy, -) -from services.retention.conversation.messages_clean_service import MessagesCleanService - - -class TestMessagesCleanService: - @pytest.fixture(autouse=True) - def mock_db_engine(self): - with patch("services.retention.conversation.messages_clean_service.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db.engine - - @pytest.fixture - def mock_db_session(self, mock_db_engine): - with patch("services.retention.conversation.messages_clean_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - yield mock_session - - @pytest.fixture - def mock_policy(self): - policy = MagicMock(spec=BillingDisabledPolicy) - return policy - - def test_run_calls_clean_messages(self, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - with patch.object(service, "_clean_messages_by_time_range") as mock_clean: - mock_clean.return_value = {"total_deleted": 5} - result = service.run() - assert result == {"total_deleted": 5} - mock_clean.assert_called_once() - - def test_clean_messages_by_time_range_basic(self, mock_db_session, mock_policy): - # Arrange - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - end_before=end_before, - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock( - rowcount=1 - ), # delete relations (this is wrong, relations delete doesn't use rowcount here, but execute) - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete messages - MagicMock(all=lambda: []), # next batch empty - ] - - # Reset side_effect to be more robust - # The service calls session.execute for: - # 1. Fetch messages - # 2. Fetch apps - # 3. Batch delete relations (8 calls if IDs exist) - # 4. 
Delete messages - - mock_returns = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # fetch messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # fetch apps - ] - # 8 deletes for relations - mock_returns.extend([MagicMock() for _ in range(8)]) - # 1 delete for messages - mock_returns.append(MagicMock(rowcount=1)) - # Final fetch messages (empty) - mock_returns.append(MagicMock(all=lambda: [])) - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] - - # Act - with patch("services.retention.conversation.messages_clean_service.time.sleep"): - stats = service.run() - - # Assert - assert stats["total_messages"] == 1 - assert stats["total_deleted"] == 1 - assert stats["batches"] == 2 - - def test_clean_messages_by_time_range_with_start_from(self, mock_db_session, mock_policy): - start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - start_from=start_from, - end_before=end_before, - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: []), # No messages - ] - - stats = service.run() - assert stats["total_messages"] == 0 - - def test_clean_messages_by_time_range_with_cursor(self, mock_db_session, mock_policy): - # Test pagination with cursor - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - end_before=end_before, - batch_size=1, - ) - - msg1_time = datetime.datetime(2024, 1, 1, 10, 0, 0) - msg2_time = datetime.datetime(2024, 1, 1, 11, 0, 0) - - mock_returns = [] - # Batch 1 - mock_returns.append(MagicMock(all=lambda: [("msg1", "app1", msg1_time)])) - mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - - # Batch 2 - mock_returns.append(MagicMock(all=lambda: [("msg2", "app1", msg2_time)])) - mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - - # Batch 3 - mock_returns.append(MagicMock(all=lambda: [])) - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] # Simplified - - with patch("services.retention.conversation.messages_clean_service.time.sleep"): - stats = service.run() - - assert stats["batches"] == 3 - assert stats["total_messages"] == 2 - - def test_clean_messages_by_time_range_dry_run(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - dry_run=True, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock(all=lambda: []), # next batch empty - ] - mock_policy.filter_message_ids.return_value = ["msg1"] - - with patch("services.retention.conversation.messages_clean_service.random.sample") as mock_sample: - mock_sample.return_value = ["msg1"] - stats = service.run() - assert stats["filtered_messages"] == 1 - assert stats["total_deleted"] == 0 # Dry run - mock_sample.assert_called() - - def test_clean_messages_by_time_range_no_apps_found(self, 
mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: []), # apps NOT found - MagicMock(all=lambda: []), # next batch empty - ] - - stats = service.run() - assert stats["total_messages"] == 1 - assert stats["total_deleted"] == 0 - - def test_clean_messages_by_time_range_no_app_ids(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: []), # next batch empty - ] - - # We need to successfully execute line 228 and 229, then return empty at 251. - # line 228: raw_messages = list(session.execute(msg_stmt).all()) - # line 251: app_ids = list({msg.app_id for msg in messages}) - - calls = [] - - def list_side_effect(arg): - calls.append(arg) - if len(calls) == 2: # This is the second call to list() in the loop - return [] - return list(arg) - - with patch("services.retention.conversation.messages_clean_service.list", side_effect=list_side_effect): - stats = service.run() - assert stats["batches"] == 2 - assert stats["total_messages"] == 1 - - def test_from_time_range_validation(self, mock_policy): - now = datetime.datetime.now() - # Test start_from >= end_before - with pytest.raises(ValueError, match="start_from .* must be less than end_before"): - MessagesCleanService.from_time_range(mock_policy, now, now) - - # Test batch_size <= 0 - with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): - MessagesCleanService.from_time_range(mock_policy, now - datetime.timedelta(days=1), now, batch_size=0) - - def test_from_time_range_success(self, mock_policy): - start = datetime.datetime(2024, 1, 1) - end = datetime.datetime(2024, 2, 1) - # Mock logger to avoid actual logging if needed, though it's fine - service = MessagesCleanService.from_time_range(mock_policy, start, end) - assert service._start_from == start - assert service._end_before == end - - def test_from_days_validation(self, mock_policy): - # Test days < 0 - with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): - MessagesCleanService.from_days(mock_policy, days=-1) - - # Test batch_size <= 0 - with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): - MessagesCleanService.from_days(mock_policy, days=30, batch_size=0) - - def test_from_days_success(self, mock_policy): - with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: - fixed_now = datetime.datetime(2024, 6, 1) - mock_now.return_value = fixed_now - - service = MessagesCleanService.from_days(mock_policy, days=10) - assert service._start_from is None - assert service._end_before == fixed_now - datetime.timedelta(days=10) - - def test_clean_messages_by_time_range_no_messages_to_delete(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock(all=lambda: []), # next batch empty - ] - 
mock_policy.filter_message_ids.return_value = [] # Policy says NO - - stats = service.run() - assert stats["total_messages"] == 1 - assert stats["filtered_messages"] == 0 - assert stats["total_deleted"] == 0 - - def test_batch_delete_message_relations_empty(self, mock_db_session): - MessagesCleanService._batch_delete_message_relations(mock_db_session, []) - mock_db_session.execute.assert_not_called() - - def test_batch_delete_message_relations_with_ids(self, mock_db_session): - MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"]) - assert mock_db_session.execute.call_count == 8 # 8 tables to clean up - - def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_returns = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - ] - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - mock_returns.append(MagicMock(all=lambda: [])) # next batch empty - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] - - with patch( - "services.retention.conversation.messages_clean_service.dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", - 500, - ): - with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep: - with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform: - mock_uniform.return_value = 300.0 - service.run() - mock_uniform.assert_called_with(0, 500) - mock_sleep.assert_called_with(0.3) diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index afea8ec92a..179518a5fa 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock import pytest import yaml +from graphon.enums import BuiltinNodeTypes from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, TRIGGER_SCHEDULE_NODE_TYPE, TRIGGER_WEBHOOK_NODE_TYPE, ) -from graphon.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import IconType from services import app_dsl_service diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 35157790ca..1bf4c0e172 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -6,13 +6,14 @@ Tests are organized by functionality and include edge cases, error handling, and both positive and negative test scenarios. 
""" -from datetime import datetime, timedelta +from datetime import timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest from sqlalchemy import asc, desc from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable from models.enums import ConversationFromSource @@ -122,8 +123,8 @@ class ConversationServiceTestDataFactory: conversation.is_deleted = kwargs.get("is_deleted", False) conversation.name = kwargs.get("name", "Test Conversation") conversation.status = kwargs.get("status", "normal") - conversation.created_at = kwargs.get("created_at", datetime.utcnow()) - conversation.updated_at = kwargs.get("updated_at", datetime.utcnow()) + conversation.created_at = kwargs.get("created_at", naive_utc_now()) + conversation.updated_at = kwargs.get("updated_at", naive_utc_now()) for key, value in kwargs.items(): setattr(conversation, key, value) return conversation @@ -152,7 +153,7 @@ class ConversationServiceTestDataFactory: message.conversation_id = conversation_id message.app_id = app_id message.query = kwargs.get("query", "Test message content") - message.created_at = kwargs.get("created_at", datetime.utcnow()) + message.created_at = kwargs.get("created_at", naive_utc_now()) for key, value in kwargs.items(): setattr(message, key, value) return message @@ -181,8 +182,8 @@ class ConversationServiceTestDataFactory: variable.conversation_id = conversation_id variable.app_id = app_id variable.data = {"name": kwargs.get("name", "test_var"), "value": kwargs.get("value", "test_value")} - variable.created_at = kwargs.get("created_at", datetime.utcnow()) - variable.updated_at = kwargs.get("updated_at", datetime.utcnow()) + variable.created_at = kwargs.get("created_at", naive_utc_now()) + variable.updated_at = kwargs.get("updated_at", naive_utc_now()) # Mock to_variable method mock_variable = Mock() @@ -302,7 +303,7 @@ class TestConversationServiceHelpers: """ # Arrange mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_conversation.updated_at = datetime.utcnow() + mock_conversation.updated_at = naive_utc_now() # Act condition = ConversationService._build_filter_condition( @@ -323,7 +324,7 @@ class TestConversationServiceHelpers: """ # Arrange mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_conversation.created_at = datetime.utcnow() + mock_conversation.created_at = naive_utc_now() # Act condition = ConversationService._build_filter_condition( @@ -668,9 +669,9 @@ class TestConversationServiceConversationalVariable: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( - created_at=datetime.utcnow() - timedelta(hours=1) + created_at=naive_utc_now() - timedelta(hours=1) ) - variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=datetime.utcnow()) + variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=naive_utc_now()) mock_session.scalar.return_value = last_variable mock_session.scalars.return_value.all.return_value = [variable] diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index da93239600..3df7d500cf 100644 --- 
a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType -from graphon.model_runtime.entities.provider_entities import FormType from models.account import Account from models.model import EndUser from models.oauth import DatasourceProvider diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 0aeecd938f..9be475d043 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -3,18 +3,19 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest - -import services.human_input_service as human_input_service_module -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormSubmissionRepository, -) from graphon.nodes.human_input.entities import ( FormDefinition, FormInput, UserAction, ) from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus + +import services.human_input_service as human_input_service_module +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) +from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from services.human_input_service import ( Form, @@ -51,11 +52,11 @@ def sample_form_record(): inputs=[], user_actions=[UserAction(id="submit", title="Submit")], rendered_content="
hello", - expiration_time=datetime.utcnow() + timedelta(hours=1), + expiration_time=naive_utc_now() + timedelta(hours=1), ), rendered_content="hello
", - created_at=datetime.utcnow(), - expiration_time=datetime.utcnow() + timedelta(hours=1), + created_at=naive_utc_now(), + expiration_time=naive_utc_now() + timedelta(hours=1), status=HumanInputFormStatus.WAITING, selected_action_id=None, submitted_data=None, @@ -101,8 +102,8 @@ def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_rec service = HumanInputService(session_factory) expired_record = dataclasses.replace( sample_form_record, - created_at=datetime.utcnow() - timedelta(hours=2), - expiration_time=datetime.utcnow() + timedelta(hours=2), + created_at=naive_utc_now() - timedelta(hours=2), + expiration_time=naive_utc_now() + timedelta(hours=2), ) monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) @@ -391,7 +392,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): service = HumanInputService(session_factory) # Submitted - submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=naive_utc_now()) with pytest.raises(human_input_service_module.FormSubmittedError): service.ensure_form_active(Form(submitted_record)) @@ -402,7 +403,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): # Expired time expired_time_record = dataclasses.replace( - sample_form_record, expiration_time=datetime.utcnow() - timedelta(minutes=1) + sample_form_record, expiration_time=naive_utc_now() - timedelta(minutes=1) ) with pytest.raises(FormExpiredError): service.ensure_form_active(Form(expired_time_record)) @@ -411,7 +412,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): def test_ensure_not_submitted_raises(sample_form_record, mock_session_factory): session_factory, _ = mock_session_factory service = HumanInputService(session_factory) - submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=naive_utc_now()) with pytest.raises(human_input_service_module.FormSubmittedError): service._ensure_not_submitted(Form(submitted_record)) diff --git a/api/tests/unit_tests/services/test_metadata_service.py b/api/tests/unit_tests/services/test_metadata_service.py deleted file mode 100644 index bbdc16d4f8..0000000000 --- a/api/tests/unit_tests/services/test_metadata_service.py +++ /dev/null @@ -1,558 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from datetime import UTC, datetime -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from models.dataset import Dataset -from services.entities.knowledge_entities.knowledge_entities import ( - DocumentMetadataOperation, - MetadataArgs, - MetadataDetail, - MetadataOperationData, -) -from services.metadata_service import MetadataService - - -@dataclass -class _DocumentStub: - id: str - name: str - uploader: str - upload_date: datetime - last_update_date: datetime - data_source_type: str - doc_metadata: dict[str, object] | None - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - mocked_db = mocker.patch("services.metadata_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -@pytest.fixture -def mock_redis_client(mocker: 
MockerFixture) -> MagicMock: - return mocker.patch("services.metadata_service.redis_client") - - -@pytest.fixture -def mock_current_account(mocker: MockerFixture) -> MagicMock: - mock_user = SimpleNamespace(id="user-1") - return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1")) - - -def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub: - now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC) - return _DocumentStub( - id=document_id, - name=f"doc-{document_id}", - uploader="qa@example.com", - upload_date=now, - last_update_date=now, - data_source_type="upload_file", - doc_metadata=doc_metadata, - ) - - -def _dataset(**kwargs: Any) -> Dataset: - return cast(Dataset, SimpleNamespace(**kwargs)) - - -def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None: - # Arrange - metadata_args = MetadataArgs(type="string", name="x" * 256) - - # Act + Assert - with pytest.raises(ValueError, match="cannot exceed 255"): - MetadataService.create_metadata("dataset-1", metadata_args) - - -def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists( - mock_db: MagicMock, - mock_current_account: MagicMock, -) -> None: - # Arrange - metadata_args = MetadataArgs(type="string", name="priority") - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - MetadataService.create_metadata("dataset-1", metadata_args) - - # Assert - mock_current_account.assert_called_once() - - -def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin( - mock_db: MagicMock, mock_current_account: MagicMock -) -> None: - # Arrange - metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name) - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="Built-in fields"): - MetadataService.create_metadata("dataset-1", metadata_args) - - -def test_create_metadata_should_persist_metadata_when_input_is_valid( - mock_db: MagicMock, mock_current_account: MagicMock -) -> None: - # Arrange - metadata_args = MetadataArgs(type="number", name="score") - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act - result = MetadataService.create_metadata("dataset-1", metadata_args) - - # Assert - assert result.tenant_id == "tenant-1" - assert result.dataset_id == "dataset-1" - assert result.type == "number" - assert result.name == "score" - assert result.created_by == "user-1" - mock_db.session.add.assert_called_once_with(result) - mock_db.session.commit.assert_called_once() - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None: - # Arrange - too_long_name = "x" * 256 - - # Act + Assert - with pytest.raises(ValueError, match="cannot exceed 255"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name) - - -def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists( - mock_db: MagicMock, mock_current_account: MagicMock -) -> None: - # Arrange - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate") - - # Assert - 
mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin( - mock_db: MagicMock, - mock_current_account: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="Built-in fields"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source) - - # Assert - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_update_bound_documents_and_return_metadata( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC) - mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now) - - metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None) - bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] - query_duplicate = MagicMock() - query_duplicate.filter_by.return_value.first.return_value = None - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = metadata - query_bindings = MagicMock() - query_bindings.filter_by.return_value.all.return_value = bindings - mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings] - - doc_1 = _build_document("1", {"old_name": "value", "other": "keep"}) - doc_2 = _build_document("2", None) - mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids") - mock_get_documents.return_value = [doc_1, doc_2] - - # Act - result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name") - - # Assert - assert result is metadata - assert metadata.name == "new_name" - assert metadata.updated_by == "user-1" - assert metadata.updated_at == fixed_now - assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"} - assert doc_2.doc_metadata == {"new_name": None} - mock_get_documents.assert_called_once_with(["doc-1", "doc-2"]) - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_return_none_when_metadata_does_not_exist( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - mock_logger = mocker.patch("services.metadata_service.logger") - - query_duplicate = MagicMock() - query_duplicate.filter_by.return_value.first.return_value = None - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = None - mock_db.session.query.side_effect = [query_duplicate, query_metadata] - - # Act - result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name") - - # Assert - assert result is None - mock_logger.exception.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - mock_current_account.assert_called_once() - - -def test_delete_metadata_should_remove_metadata_and_related_document_fields( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - metadata = 
SimpleNamespace(id="metadata-1", name="obsolete") - bindings = [SimpleNamespace(document_id="doc-1")] - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = metadata - query_bindings = MagicMock() - query_bindings.filter_by.return_value.all.return_value = bindings - mock_db.session.query.side_effect = [query_metadata, query_bindings] - - document = _build_document("1", {"obsolete": "legacy", "remaining": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document]) - - # Act - result = MetadataService.delete_metadata("dataset-1", "metadata-1") - - # Assert - assert result is metadata - assert document.doc_metadata == {"remaining": "value"} - mock_db.session.delete.assert_called_once_with(metadata) - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_delete_metadata_should_return_none_when_metadata_is_missing( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - mock_logger = mocker.patch("services.metadata_service.logger") - - # Act - result = MetadataService.delete_metadata("dataset-1", "missing-id") - - # Assert - assert result is None - mock_logger.exception.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_get_built_in_fields_should_return_all_expected_fields() -> None: - # Arrange - expected_names = { - BuiltInField.document_name, - BuiltInField.uploader, - BuiltInField.upload_date, - BuiltInField.last_update_date, - BuiltInField.source, - } - - # Act - result = MetadataService.get_built_in_fields() - - # Assert - assert {item["name"] for item in result} == expected_names - assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"] - - -def test_enable_built_in_field_should_return_immediately_when_already_enabled( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") - - # Act - MetadataService.enable_built_in_field(dataset) - - # Assert - get_docs.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_enable_built_in_field_should_populate_documents_and_enable_flag( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - doc_1 = _build_document("1", {"custom": "value"}) - doc_2 = _build_document("2", None) - mocker.patch( - "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", - return_value=[doc_1, doc_2], - ) - - # Act - MetadataService.enable_built_in_field(dataset) - - # Assert - assert dataset.built_in_field_enabled is True - assert doc_1.doc_metadata is not None - assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1" - assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file - assert doc_2.doc_metadata is not None - assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com" - mock_db.session.commit.assert_called_once() - 
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_disable_built_in_field_should_return_immediately_when_already_disabled( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") - - # Act - MetadataService.disable_built_in_field(dataset) - - # Assert - get_docs.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - document = _build_document( - "1", - { - BuiltInField.document_name: "doc", - BuiltInField.uploader: "user", - BuiltInField.upload_date: 1.0, - BuiltInField.last_update_date: 2.0, - BuiltInField.source: MetadataDataSource.upload_file, - "custom": "keep", - }, - ) - mocker.patch( - "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", - return_value=[document], - ) - - # Act - MetadataService.disable_built_in_field(dataset) - - # Assert - assert dataset.built_in_field_enabled is False - assert document.doc_metadata == {"custom": "keep"} - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - document = _build_document("1", {"legacy": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) - delete_chain = mock_db.session.query.return_value.filter_by.return_value - delete_chain.delete.return_value = 1 - operation = DocumentMetadataOperation( - document_id="1", - metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")], - partial_update=False, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - assert document.doc_metadata == {"priority": "high"} - delete_chain.delete.assert_called_once() - assert mock_db.session.commit.call_count == 1 - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") - mock_current_account.assert_called_once() - - -def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - document = _build_document("1", {"existing": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() - operation = DocumentMetadataOperation( - document_id="1", - metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")], - partial_update=True, - ) - 
metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - assert document.doc_metadata is not None - assert document.doc_metadata["existing"] == "value" - assert document.doc_metadata["new_key"] == "new_value" - assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file - assert mock_db.session.commit.call_count == 1 - assert mock_db.session.add.call_count == 1 - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") - mock_current_account.assert_called_once() - - -def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None) - operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act + Assert - with pytest.raises(ValueError, match="Document not found"): - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - mock_db.session.rollback.assert_called_once() - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404") - - -@pytest.mark.parametrize( - ("dataset_id", "document_id", "expected_key"), - [ - ("dataset-1", None, "dataset_metadata_lock_dataset-1"), - (None, "doc-1", "document_metadata_lock_doc-1"), - ], -) -def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked( - dataset_id: str | None, - document_id: str | None, - expected_key: str, - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act - MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id) - - # Assert - mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600) - - -def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = 1 - - # Act + Assert - with pytest.raises(ValueError, match="knowledge base metadata operation is running"): - MetadataService.knowledge_base_metadata_lock_check("dataset-1", None) - - -def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = 1 - - # Act + Assert - with pytest.raises(ValueError, match="document metadata operation is running"): - MetadataService.knowledge_base_metadata_lock_check(None, "doc-1") - - -def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None: - # Arrange - dataset = _dataset( - id="dataset-1", - built_in_field_enabled=True, - doc_metadata=[ - {"id": "meta-1", "name": "priority", "type": "string"}, - {"id": "built-in", "name": "ignored", "type": "string"}, - {"id": "meta-2", "name": "score", "type": "number"}, - ], - ) - count_chain = mock_db.session.query.return_value.filter_by.return_value - count_chain.count.side_effect = [3, 1] - - # Act - result = MetadataService.get_dataset_metadatas(dataset) - - # Assert - assert result["built_in_field_enabled"] is True - assert result["doc_metadata"] == [ - {"id": "meta-1", "name": "priority", "type": 
"string", "count": 3}, - {"id": "meta-2", "name": "score", "type": "number", "count": 1}, - ] - - -def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None: - # Arrange - dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None) - - # Act - result = MetadataService.get_dataset_metadatas(dataset) - - # Assert - assert result == {"doc_metadata": [], "built_in_field_enabled": False} - mock_db.session.query.assert_not_called() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index 1e898ada11..b43e79dff5 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -6,9 +6,6 @@ from typing import Any, cast from unittest.mock import MagicMock import pytest -from pytest_mock import MockerFixture - -from constants import HIDDEN_VALUE from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -18,6 +15,9 @@ from graphon.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE from models.provider import LoadBalancingModelConfig from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 97f3bd6f01..1bd979b9ec 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -1,11 +1,11 @@ import types import pytest - -from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ConfigurateMethod + +from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration from models.provider import ProviderType from services.model_provider_service import ModelProviderService diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py deleted file mode 100644 index b09463b1bc..0000000000 --- a/api/tests/unit_tests/services/test_tag_service.py +++ /dev/null @@ -1,1336 +0,0 @@ -""" -Comprehensive unit tests for TagService. - -This test suite provides complete coverage of tag management operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -The TagService is responsible for managing tags that can be associated with -datasets (knowledge bases) and applications. Tags enable users to organize, -categorize, and filter their content effectively. - -## Test Coverage - -### 1. Tag Retrieval (TestTagServiceRetrieval) -Tests tag listing and filtering: -- Get tags with binding counts -- Filter tags by keyword (case-insensitive) -- Get tags by target ID (apps/datasets) -- Get tags by tag name -- Get target IDs by tag IDs -- Empty results handling - -### 2. 
Tag CRUD Operations (TestTagServiceCRUD) -Tests tag creation, update, and deletion: -- Create new tags -- Prevent duplicate tag names -- Update tag names -- Update with duplicate name validation -- Delete tags and cascade delete bindings -- Get tag binding counts -- NotFound error handling - -### 3. Tag Binding Operations (TestTagServiceBindings) -Tests tag-to-resource associations: -- Save tag bindings (apps/datasets) -- Prevent duplicate bindings (idempotent) -- Delete tag bindings -- Check target exists validation -- Batch binding operations - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, current_user) are mocked - for fast, isolated unit tests -- **Factory Pattern**: TagServiceTestDataFactory provides consistent test data -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and side effects - (database operations, method calls) - -## Key Concepts - -**Tag Types:** -- knowledge: Tags for datasets/knowledge bases -- app: Tags for applications - -**Tag Bindings:** -- Many-to-many relationship between tags and resources -- Each binding links a tag to a specific app or dataset -- Bindings are tenant-scoped for multi-tenancy - -**Validation:** -- Tag names must be unique within tenant and type -- Target resources must exist before binding -- Cascade deletion of bindings when tag is deleted -""" - - -# ============================================================================ -# IMPORTS -# ============================================================================ - -from datetime import UTC, datetime -from unittest.mock import MagicMock, Mock, create_autospec, patch - -import pytest -from werkzeug.exceptions import NotFound - -from models.dataset import Dataset -from models.enums import TagType -from models.model import App, Tag, TagBinding -from services.tag_service import TagService - -# ============================================================================ -# TEST DATA FACTORY -# ============================================================================ - - -class TagServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - tag-related operations. This factory ensures all test data follows the - same structure and reduces code duplication across tests. - - The factory pattern is used here to: - - Ensure consistent test data creation - - Reduce boilerplate code in individual tests - - Make tests more maintainable and readable - - Centralize mock object configuration - """ - - @staticmethod - def create_tag_mock( - tag_id: str = "tag-123", - name: str = "Test Tag", - tag_type: TagType = TagType.APP, - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """ - Create a mock Tag object. - - This method creates a mock Tag instance with all required attributes - set to sensible defaults. Additional attributes can be passed via - kwargs to customize the mock for specific test scenarios. - - Args: - tag_id: Unique identifier for the tag - name: Tag name (e.g., "Frontend", "Backend", "Data Science") - tag_type: Type of tag ('app' or 'knowledge') - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - (e.g., created_by, created_at, etc.) - - Returns: - Mock Tag object with specified attributes - - Example: - >>> tag = factory.create_tag_mock( - ... tag_id="tag-456", - ... name="Machine Learning", - ... tag_type="knowledge" - ... 
) - """ - # Create a mock that matches the Tag model interface - tag = create_autospec(Tag, instance=True) - - # Set core attributes - tag.id = tag_id - tag.name = name - tag.type = tag_type - tag.tenant_id = tenant_id - - # Set default optional attributes - tag.created_by = kwargs.pop("created_by", "user-123") - tag.created_at = kwargs.pop("created_at", datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)) - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(tag, key, value) - - return tag - - @staticmethod - def create_tag_binding_mock( - binding_id: str = "binding-123", - tag_id: str = "tag-123", - target_id: str = "target-123", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """ - Create a mock TagBinding object. - - TagBindings represent the many-to-many relationship between tags - and resources (datasets or apps). This method creates a mock - binding with the necessary attributes. - - Args: - binding_id: Unique identifier for the binding - tag_id: Associated tag identifier - target_id: Associated target (app/dataset) identifier - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - (e.g., created_by, etc.) - - Returns: - Mock TagBinding object with specified attributes - - Example: - >>> binding = factory.create_tag_binding_mock( - ... tag_id="tag-456", - ... target_id="dataset-789", - ... tenant_id="tenant-123" - ... ) - """ - # Create a mock that matches the TagBinding model interface - binding = create_autospec(TagBinding, instance=True) - - # Set core attributes - binding.id = binding_id - binding.tag_id = tag_id - binding.target_id = target_id - binding.tenant_id = tenant_id - - # Set default optional attributes - binding.created_by = kwargs.pop("created_by", "user-123") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(binding, key, value) - - return binding - - @staticmethod - def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock App object. - - This method creates a mock App instance for testing tag bindings - to applications. Apps are one of the two target types that tags - can be bound to (the other being datasets/knowledge bases). - - Args: - app_id: Unique identifier for the app - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - - Returns: - Mock App object with specified attributes - - Example: - >>> app = factory.create_app_mock( - ... app_id="app-456", - ... name="My Chat App" - ... ) - """ - # Create a mock that matches the App model interface - app = create_autospec(App, instance=True) - - # Set core attributes - app.id = app_id - app.tenant_id = tenant_id - app.name = kwargs.get("name", "Test App") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(app, key, value) - - return app - - @staticmethod - def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock Dataset object. - - This method creates a mock Dataset instance for testing tag bindings - to knowledge bases. Datasets (knowledge bases) are one of the two - target types that tags can be bound to (the other being apps). 
- - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Dataset object with specified attributes - - Example: - >>> dataset = factory.create_dataset_mock( - ... dataset_id="dataset-456", - ... name="My Knowledge Base" - ... ) - """ - # Create a mock that matches the Dataset model interface - dataset = create_autospec(Dataset, instance=True) - - # Set core attributes - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.name = kwargs.pop("name", "Test Dataset") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(dataset, key, value) - - return dataset - - -# ============================================================================ -# PYTEST FIXTURES -# ============================================================================ - - -@pytest.fixture -def factory(): - """ - Provide the test data factory to all tests. - - This fixture makes the TagServiceTestDataFactory available to all test - methods, allowing them to create consistent mock objects easily. - - Returns: - TagServiceTestDataFactory class - """ - return TagServiceTestDataFactory - - -# ============================================================================ -# TAG RETRIEVAL TESTS -# ============================================================================ - - -class TestTagServiceRetrieval: - """ - Test tag retrieval operations. - - This test class covers all methods related to retrieving and querying - tags from the system. These operations are read-only and do not modify - the database state. - - Methods tested: - - get_tags: Retrieve tags with optional keyword filtering - - get_target_ids_by_tag_ids: Get target IDs (datasets/apps) by tag IDs - - get_tag_by_tag_name: Find tags by exact name match - - get_tags_by_target_id: Get all tags bound to a specific target - """ - - @patch("services.tag_service.db.session") - def test_get_tags_with_binding_counts(self, mock_db_session, factory): - """ - Test retrieving tags with their binding counts. - - This test verifies that the get_tags method correctly retrieves - a list of tags along with the count of how many resources - (datasets/apps) are bound to each tag. 
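The query being mocked by this test plausibly takes the following SQLAlchemy shape: an outer join so tags with zero bindings still appear, grouped per tag so the count aggregates correctly. Column names come from the mocks in this file; treat the exact shape as an assumption, not the verified service code.

from sqlalchemy import func

from models.model import Tag, TagBinding


def tags_with_binding_counts(session, tag_type: str, tenant_id: str):
    # LEFT OUTER JOIN keeps tags with no bindings; COUNT over the joined
    # (possibly all-NULL) binding ids yields 0 for those rows. Grouping by
    # the primary key lets PostgreSQL reference the other Tag columns.
    return (
        session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id))
        .outerjoin(TagBinding, TagBinding.tag_id == Tag.id)
        .where(Tag.type == tag_type, Tag.tenant_id == tenant_id)
        .group_by(Tag.id)
        .order_by(Tag.created_at.desc())
        .all()
    )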
- - The method should: - - Query tags filtered by type and tenant - - Include binding counts via a LEFT OUTER JOIN - - Return results ordered by creation date (newest first) - - Expected behavior: - - Returns a list of tuples containing (id, type, name, binding_count) - - Each tag includes its binding count - - Results are ordered by creation date descending - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - - # Mock query results: tuples of (tag_id, type, name, binding_count) - # This simulates the SQL query result with aggregated binding counts - mock_results = [ - ("tag-1", "app", "Frontend", 5), # Frontend tag with 5 bindings - ("tag-2", "app", "Backend", 3), # Backend tag with 3 bindings - ("tag-3", "app", "API", 0), # API tag with no bindings - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.outerjoin.return_value = mock_query # LEFT OUTER JOIN with TagBinding - mock_query.where.return_value = mock_query # WHERE clause for filtering - mock_query.group_by.return_value = mock_query # GROUP BY for aggregation - mock_query.order_by.return_value = mock_query # ORDER BY for sorting - mock_query.all.return_value = mock_results # Final result - - # Act - # Execute the method under test - results = TagService.get_tags(tag_type=tag_type, current_tenant_id=tenant_id) - - # Assert - # Verify the results match expectations - assert len(results) == 3, "Should return 3 tags" - - # Verify each tag's data structure - assert results[0] == ("tag-1", "app", "Frontend", 5), "First tag should match" - assert results[1] == ("tag-2", "app", "Backend", 3), "Second tag should match" - assert results[2] == ("tag-3", "app", "API", 0), "Third tag should match" - - # Verify database query was called - mock_db_session.query.assert_called_once() - - @patch("services.tag_service.db.session") - def test_get_tags_with_keyword_filter(self, mock_db_session, factory): - """ - Test retrieving tags filtered by keyword (case-insensitive). - - This test verifies that the get_tags method correctly filters tags - by keyword when a keyword parameter is provided. The filtering - should be case-insensitive and support partial matches. 
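The keyword branch only appends one predicate; `ilike` supplies the case-insensitive partial match this test describes. A sketch, with `Tag` assumed as in the previous snippet:

from models.model import Tag


def apply_keyword_filter(query, keyword: str | None):
    # "data" matches "Database" and "Data Science" alike: ILIKE is
    # case-insensitive and the surrounding %...% allows partial hits.
    if keyword:
        query = query.where(Tag.name.ilike(f"%{keyword}%"))
    return query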
- - The method should: - - Apply an additional WHERE clause when keyword is provided - - Use ILIKE for case-insensitive pattern matching - - Support partial matches (e.g., "data" matches "Database" and "Data Science") - - Expected behavior: - - Returns only tags whose names contain the keyword - - Matching is case-insensitive - - Partial matches are supported - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "knowledge" - keyword = "data" - - # Mock query results filtered by keyword - mock_results = [ - ("tag-1", "knowledge", "Database", 2), - ("tag-2", "knowledge", "Data Science", 4), - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.group_by.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = mock_results - - # Act - # Execute the method with keyword filter - results = TagService.get_tags(tag_type=tag_type, current_tenant_id=tenant_id, keyword=keyword) - - # Assert - # Verify filtered results - assert len(results) == 2, "Should return 2 matching tags" - - # Verify keyword filter was applied - # The where() method should be called at least twice: - # 1. Initial WHERE clause for type and tenant - # 2. Additional WHERE clause for keyword filtering - assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" - - @patch("services.tag_service.db.session") - def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): - """ - Test retrieving target IDs by tag IDs. - - This test verifies that the get_target_ids_by_tag_ids method correctly - retrieves all target IDs (dataset/app IDs) that are bound to the - specified tags. This is useful for filtering datasets or apps by tags. 
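The two consecutive `scalars()` results configured in this test map onto a two-query flow, roughly like the sketch below (an assumption about the service's internals, inferred from the mock setup):

from sqlalchemy import select

from models.model import Tag, TagBinding


def get_target_ids_by_tag_ids(session, tag_type, tenant_id, tag_ids):
    if not tag_ids:
        return []  # early return: no database round-trips for empty input
    valid_tags = session.scalars(
        select(Tag).where(
            Tag.id.in_(tag_ids), Tag.type == tag_type, Tag.tenant_id == tenant_id
        )
    ).all()
    if not valid_tags:
        return []
    return session.scalars(
        select(TagBinding.target_id).where(
            TagBinding.tag_id.in_([tag.id for tag in valid_tags])
        )
    ).all()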
- - The method should: - - First validate and filter tags by type and tenant - - Then find all bindings for those tags - - Return the target IDs from those bindings - - Expected behavior: - - Returns a list of target IDs (strings) - - Only includes targets bound to valid tags - - Respects tenant and type filtering - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - tag_ids = ["tag-1", "tag-2"] - - # Create mock tag objects - tags = [ - factory.create_tag_mock(tag_id="tag-1", tenant_id=tenant_id, tag_type=tag_type), - factory.create_tag_mock(tag_id="tag-2", tenant_id=tenant_id, tag_type=tag_type), - ] - - # Mock target IDs that are bound to these tags - target_ids = ["app-1", "app-2", "app-3"] - - # Mock tag query (first scalars call) - mock_scalars_tags = MagicMock() - mock_scalars_tags.all.return_value = tags - - # Mock binding query (second scalars call) - mock_scalars_bindings = MagicMock() - mock_scalars_bindings.all.return_value = target_ids - - # Configure side_effect to return different mocks for each scalars() call - mock_db_session.scalars.side_effect = [mock_scalars_tags, mock_scalars_bindings] - - # Act - # Execute the method under test - results = TagService.get_target_ids_by_tag_ids(tag_type=tag_type, current_tenant_id=tenant_id, tag_ids=tag_ids) - - # Assert - # Verify results match expected target IDs - assert results == target_ids, "Should return all target IDs bound to tags" - - # Verify both queries were executed - assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" - - @patch("services.tag_service.db.session") - def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): - """ - Test that empty tag_ids returns empty list. - - This test verifies the edge case handling when an empty list of - tag IDs is provided. The method should return early without - executing any database queries. - - Expected behavior: - - Returns empty list immediately - - Does not execute any database queries - - Handles empty input gracefully - """ - # Arrange - # Set up test parameters with empty tag IDs - tenant_id = "tenant-123" - tag_type = "app" - - # Act - # Execute the method with empty tag IDs list - results = TagService.get_target_ids_by_tag_ids(tag_type=tag_type, current_tenant_id=tenant_id, tag_ids=[]) - - # Assert - # Verify empty result and no database queries - assert results == [], "Should return empty list for empty input" - mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" - - @patch("services.tag_service.db.session") - def test_get_tag_by_tag_name(self, mock_db_session, factory): - """ - Test retrieving tags by name. - - This test verifies that the get_tag_by_tag_name method correctly - finds tags by their exact name. This is used for duplicate name - checking and tag lookup operations. 
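Exact-name lookup is mostly input validation plus one filtered query; a sketch of the shape implied by the tests that follow:

from sqlalchemy import select

from models.model import Tag


def get_tag_by_tag_name(session, tag_type: str, tenant_id: str, tag_name: str):
    if not tag_type or not tag_name:
        return []  # guard clause: invalid input never reaches the database
    return session.scalars(
        select(Tag).where(
            Tag.type == tag_type,
            Tag.tenant_id == tenant_id,
            Tag.name == tag_name,  # exact, case-sensitive match
        )
    ).all()

Incidentally, expressions like `mock_db_session.scalars.assert_not_called(), "Should not query database"` in these tests evaluate to a discarded tuple: the `Mock` assertion helpers accept no message argument, so the trailing string is inert (harmless, but misleading).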
- - The method should: - - Perform exact name matching (case-sensitive) - - Filter by type and tenant - - Return a list of matching tags (usually 0 or 1) - - Expected behavior: - - Returns list of tags with matching name - - Respects type and tenant filtering - - Returns empty list if no matches found - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - tag_name = "Production" - - # Create mock tag with matching name - tags = [factory.create_tag_mock(name=tag_name, tag_type=tag_type, tenant_id=tenant_id)] - - # Configure mock database session - mock_scalars = MagicMock() - mock_scalars.all.return_value = tags - mock_db_session.scalars.return_value = mock_scalars - - # Act - # Execute the method under test - results = TagService.get_tag_by_tag_name(tag_type=tag_type, current_tenant_id=tenant_id, tag_name=tag_name) - - # Assert - # Verify tag was found - assert len(results) == 1, "Should find exactly one tag" - assert results[0].name == tag_name, "Tag name should match" - - @patch("services.tag_service.db.session") - def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): - """ - Test that missing tag_type or tag_name returns empty list. - - This test verifies the input validation for the get_tag_by_tag_name - method. When either tag_type or tag_name is empty or missing, - the method should return early without querying the database. - - Expected behavior: - - Returns empty list for empty tag_type - - Returns empty list for empty tag_name - - Does not execute database queries for invalid input - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - - # Act & Assert - # Test with empty tag_type - assert TagService.get_tag_by_tag_name("", tenant_id, "name") == [], "Should return empty for empty type" - - # Test with empty tag_name - assert TagService.get_tag_by_tag_name("app", tenant_id, "") == [], "Should return empty for empty name" - - # Verify no database queries were executed - mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" - - @patch("services.tag_service.db.session") - def test_get_tags_by_target_id(self, mock_db_session, factory): - """ - Test retrieving tags associated with a specific target. - - This test verifies that the get_tags_by_target_id method correctly - retrieves all tags that are bound to a specific target (dataset or app). - This is useful for displaying tags associated with a resource. 
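The join that the mock chain in this test stands in for presumably resembles the following sketch (again an inference from the mocked `join`/`where` calls, not the verified implementation):

from models.model import Tag, TagBinding


def get_tags_by_target_id(session, tag_type: str, tenant_id: str, target_id: str):
    # Tags reach a target only through TagBinding, hence the inner join.
    return (
        session.query(Tag)
        .join(TagBinding, TagBinding.tag_id == Tag.id)
        .where(
            TagBinding.target_id == target_id,
            TagBinding.tenant_id == tenant_id,
            Tag.type == tag_type,
        )
        .all()
    )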
- - The method should: - - Join Tag and TagBinding tables - - Filter by target_id, tenant, and type - - Return all tags bound to the target - - Expected behavior: - - Returns list of Tag objects bound to the target - - Respects tenant and type filtering - - Returns empty list if no tags are bound - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - target_id = "app-123" - - # Create mock tags that are bound to the target - tags = [ - factory.create_tag_mock(tag_id="tag-1", name="Frontend"), - factory.create_tag_mock(tag_id="tag-2", name="Production"), - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.join.return_value = mock_query # JOIN with TagBinding - mock_query.where.return_value = mock_query # WHERE clause for filtering - mock_query.all.return_value = tags # Final result - - # Act - # Execute the method under test - results = TagService.get_tags_by_target_id(tag_type=tag_type, current_tenant_id=tenant_id, target_id=target_id) - - # Assert - # Verify tags were retrieved - assert len(results) == 2, "Should return 2 tags bound to target" - - # Verify tag names - assert results[0].name == "Frontend", "First tag name should match" - assert results[1].name == "Production", "Second tag name should match" - - -# ============================================================================ -# TAG CRUD OPERATIONS TESTS -# ============================================================================ - - -class TestTagServiceCRUD: - """ - Test tag CRUD operations. - - This test class covers all Create, Read, Update, and Delete operations - for tags. These operations modify the database state and require proper - transaction handling and validation. - - Methods tested: - - save_tags: Create new tags - - update_tags: Update existing tag names - - delete_tag: Delete tags and cascade delete bindings - - get_tag_binding_count: Get count of bindings for a tag - """ - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - @patch("services.tag_service.uuid.uuid4", autospec=True) - def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): - """ - Test creating a new tag. - - This test verifies that the save_tags method correctly creates a new - tag in the database with all required attributes. The method should - validate uniqueness, generate a UUID, and persist the tag. 
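Pieced together from the patches below (uuid4, the duplicate-name check, current_user), the creation flow being verified is roughly the sketch that follows; the field names are taken from the assertions, the rest is assumed.

import uuid

from models.model import Tag
from services.tag_service import TagService


def save_tag(session, current_user, args: dict) -> Tag:
    # Reject duplicates within the same tenant and type before inserting.
    if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
        raise ValueError("Tag name already exists")
    tag = Tag(
        id=str(uuid.uuid4()),
        name=args["name"],
        type=args["type"],
        created_by=current_user.id,
        tenant_id=current_user.current_tenant_id,
    )
    session.add(tag)
    session.commit()
    return tag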
- - The method should: - - Check for duplicate tag names (via get_tag_by_tag_name) - - Generate a unique UUID for the tag ID - - Set user and tenant information from current_user - - Persist the tag to the database - - Commit the transaction - - Expected behavior: - - Creates tag with correct attributes - - Assigns UUID to tag ID - - Sets created_by from current_user - - Sets tenant_id from current_user - - Commits to database - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Mock UUID generation - mock_uuid.return_value = "new-tag-id" - - # Mock no existing tag (duplicate check passes) - mock_get_tag_by_name.return_value = [] - - # Prepare tag creation arguments - args = {"name": "New Tag", "type": "app"} - - # Act - # Execute the method under test - result = TagService.save_tags(args) - - # Assert - # Verify tag was added to database session - mock_db_session.add.assert_called_once(), "Should add tag to session" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - # Verify tag attributes - added_tag = mock_db_session.add.call_args[0][0] - assert added_tag.name == "New Tag", "Tag name should match" - assert added_tag.type == TagType.APP, "Tag type should match" - assert added_tag.created_by == "user-123", "Created by should match current user" - assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory): - """ - Test that creating a tag with duplicate name raises ValueError. - - This test verifies that the save_tags method correctly prevents - duplicate tag names within the same tenant and type. Tag names - must be unique per tenant and type combination. - - Expected behavior: - - Raises ValueError when duplicate name is detected - - Error message indicates "Tag name already exists" - - Does not create the tag - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing tag with same name (duplicate detected) - existing_tag = factory.create_tag_mock(name="Existing Tag") - mock_get_tag_by_name.return_value = [existing_tag] - - # Prepare tag creation arguments with duplicate name - args = {"name": "Existing Tag", "type": "app"} - - # Act & Assert - # Verify ValueError is raised for duplicate name - with pytest.raises(ValueError, match="Tag name already exists"): - TagService.save_tags(args) - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): - """ - Test updating a tag name. - - This test verifies that the update_tags method correctly updates - an existing tag's name while preserving other attributes. The method - should validate uniqueness of the new name and ensure the tag exists. 
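The corresponding update path, as the mocks suggest, re-runs the duplicate check and then mutates the fetched row. A hedged sketch:

from werkzeug.exceptions import NotFound

from models.model import Tag
from services.tag_service import TagService


def update_tag_name(session, tenant_id: str, tag_id: str, new_name: str, tag_type: str) -> Tag:
    if TagService.get_tag_by_tag_name(tag_type, tenant_id, new_name):
        raise ValueError("Tag name already exists")
    tag = session.query(Tag).where(Tag.id == tag_id).first()
    if tag is None:
        raise NotFound("Tag not found")
    tag.name = new_name  # other attributes are left untouched
    session.commit()
    return tag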
- - The method should: - - Check for duplicate tag names (excluding the current tag) - - Find the tag by ID - - Update the tag name - - Commit the transaction - - Expected behavior: - - Updates tag name successfully - - Preserves other tag attributes - - Commits to database - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock no duplicate name (update check passes) - mock_get_tag_by_name.return_value = [] - - # Create mock tag to be updated - tag = factory.create_tag_mock(tag_id="tag-123", name="Old Name") - - # Configure mock database session to return the tag - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = tag - - # Prepare update arguments - args = {"name": "New Name", "type": "app"} - - # Act - # Execute the method under test - result = TagService.update_tags(args, tag_id="tag-123") - - # Assert - # Verify tag name was updated - assert tag.name == "New Name", "Tag name should be updated" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - def test_update_tags_raises_error_for_duplicate_name( - self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory - ): - """ - Test that updating to a duplicate name raises ValueError. - - This test verifies that the update_tags method correctly prevents - updating a tag to a name that already exists for another tag - within the same tenant and type. - - Expected behavior: - - Raises ValueError when duplicate name is detected - - Error message indicates "Tag name already exists" - - Does not update the tag - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing tag with the duplicate name - existing_tag = factory.create_tag_mock(name="Duplicate Name") - mock_get_tag_by_name.return_value = [existing_tag] - - # Prepare update arguments with duplicate name - args = {"name": "Duplicate Name", "type": "app"} - - # Act & Assert - # Verify ValueError is raised for duplicate name - with pytest.raises(ValueError, match="Tag name already exists"): - TagService.update_tags(args, tag_id="tag-123") - - @patch("services.tag_service.db.session") - def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): - """ - Test that updating a non-existent tag raises NotFound. - - This test verifies that the update_tags method correctly handles - the case when attempting to update a tag that does not exist. - This prevents silent failures and provides clear error feedback. 
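One detail worth remembering about the `pytest.raises(..., match=...)` assertions in this class: `match` is applied with `re.search` against the string form of the exception, so it is a substring/regex test, not an equality test. For example:

import pytest
from werkzeug.exceptions import NotFound


def test_match_is_a_regex_search():
    with pytest.raises(NotFound, match="Tag not found"):
        # Any message *containing* the pattern satisfies match=.
        raise NotFound("Tag not found: id=nonexistent")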
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Tag not found" - - Does not attempt to update or commit - """ - # Arrange - # Configure mock database session to return None (tag not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock duplicate check and current_user - with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True): - with patch("services.tag_service.current_user", autospec=True) as mock_user: - mock_user.current_tenant_id = "tenant-123" - args = {"name": "New Name", "type": "app"} - - # Act & Assert - # Verify NotFound is raised for non-existent tag - with pytest.raises(NotFound, match="Tag not found"): - TagService.update_tags(args, tag_id="nonexistent") - - @patch("services.tag_service.db.session") - def test_get_tag_binding_count(self, mock_db_session, factory): - """ - Test getting the count of bindings for a tag. - - This test verifies that the get_tag_binding_count method correctly - counts how many resources (datasets/apps) are bound to a specific tag. - This is useful for displaying tag usage statistics. - - The method should: - - Query TagBinding table filtered by tag_id - - Return the count of matching bindings - - Expected behavior: - - Returns integer count of bindings - - Returns 0 for tags with no bindings - """ - # Arrange - # Set up test parameters - tag_id = "tag-123" - expected_count = 5 - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.count.return_value = expected_count - - # Act - # Execute the method under test - result = TagService.get_tag_binding_count(tag_id) - - # Assert - # Verify count matches expectation - assert result == expected_count, "Binding count should match" - - @patch("services.tag_service.db.session") - def test_delete_tag(self, mock_db_session, factory): - """ - Test deleting a tag and its bindings. - - This test verifies that the delete_tag method correctly deletes - a tag along with all its associated bindings (cascade delete). - This ensures data integrity and prevents orphaned bindings. 
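The delete-count assertion below (1 tag + 3 bindings = 4) pins down a cascade implemented in application code rather than at the database level; sketched under the same assumptions as the earlier snippets:

from sqlalchemy import select

from models.model import Tag, TagBinding
from werkzeug.exceptions import NotFound


def delete_tag(session, tag_id: str) -> None:
    tag = session.query(Tag).where(Tag.id == tag_id).first()
    if tag is None:
        raise NotFound("Tag not found")
    session.delete(tag)
    # Application-level cascade: remove every binding that references the
    # tag, so no orphaned TagBinding rows survive the deletion.
    for binding in session.scalars(
        select(TagBinding).where(TagBinding.tag_id == tag_id)
    ).all():
        session.delete(binding)
    session.commit()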
- - The method should: - - Find the tag by ID - - Delete the tag - - Find all bindings for the tag - - Delete all bindings (cascade delete) - - Commit the transaction - - Expected behavior: - - Deletes tag from database - - Deletes all associated bindings - - Commits transaction - """ - # Arrange - # Set up test parameters - tag_id = "tag-123" - - # Create mock tag to be deleted - tag = factory.create_tag_mock(tag_id=tag_id) - - # Create mock bindings that will be cascade deleted - bindings = [factory.create_tag_binding_mock(binding_id=f"binding-{i}", tag_id=tag_id) for i in range(3)] - - # Configure mock database session for tag query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = tag - - # Configure mock database session for bindings query - mock_scalars = MagicMock() - mock_scalars.all.return_value = bindings - mock_db_session.scalars.return_value = mock_scalars - - # Act - # Execute the method under test - TagService.delete_tag(tag_id) - - # Assert - # Verify tag and bindings were deleted - mock_db_session.delete.assert_called(), "Should call delete method" - - # Verify delete was called 4 times (1 tag + 3 bindings) - assert mock_db_session.delete.call_count == 4, "Should delete tag and all bindings" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.db.session") - def test_delete_tag_raises_not_found(self, mock_db_session, factory): - """ - Test that deleting a non-existent tag raises NotFound. - - This test verifies that the delete_tag method correctly handles - the case when attempting to delete a tag that does not exist. - This prevents silent failures and provides clear error feedback. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "Tag not found" - - Does not attempt to delete or commit - """ - # Arrange - # Configure mock database session to return None (tag not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act & Assert - # Verify NotFound is raised for non-existent tag - with pytest.raises(NotFound, match="Tag not found"): - TagService.delete_tag("nonexistent") - - -# ============================================================================ -# TAG BINDING OPERATIONS TESTS -# ============================================================================ - - -class TestTagServiceBindings: - """ - Test tag binding operations. - - This test class covers all operations related to binding tags to - resources (datasets and apps). Tag bindings create the many-to-many - relationship between tags and resources. - - Methods tested: - - save_tag_binding: Create bindings between tags and targets - - delete_tag_binding: Remove bindings between tags and targets - - check_target_exists: Validate target (dataset/app) existence - """ - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): - """ - Test creating tag bindings. - - This test verifies that the save_tag_binding method correctly - creates bindings between tags and a target resource (dataset or app). 
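The idempotency that the next two tests pin down comes from a lookup-before-insert per tag id, approximately as below (field names mirror the factory; the rest is an assumption):

import uuid

from models.model import TagBinding
from services.tag_service import TagService


def save_tag_binding(session, current_user, args: dict) -> None:
    TagService.check_target_exists(args["type"], args["target_id"])
    for tag_id in args["tag_ids"]:
        existing = (
            session.query(TagBinding)
            .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
            .first()
        )
        if existing:
            continue  # already bound: skip, making the operation idempotent
        session.add(
            TagBinding(
                id=str(uuid.uuid4()),
                tag_id=tag_id,
                target_id=args["target_id"],
                tenant_id=current_user.current_tenant_id,
                created_by=current_user.id,
            )
        )
    session.commit()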
- The method supports batch binding of multiple tags to a single target. - - The method should: - - Validate target exists (via check_target_exists) - - Check for existing bindings to avoid duplicates - - Create new bindings for tags that aren't already bound - - Commit the transaction - - Expected behavior: - - Validates target exists - - Creates bindings for each tag in tag_ids - - Skips tags that are already bound (idempotent) - - Commits transaction - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (no existing bindings) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # No existing bindings - - # Prepare binding arguments (batch binding) - args = {"type": "app", "target_id": "app-123", "tag_ids": ["tag-1", "tag-2"]} - - # Act - # Execute the method under test - TagService.save_tag_binding(args) - - # Assert - # Verify target existence was checked - mock_check_target.assert_called_once_with("app", "app-123"), "Should validate target exists" - - # Verify bindings were created (2 bindings for 2 tags) - assert mock_db_session.add.call_count == 2, "Should create 2 bindings" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): - """ - Test that saving duplicate bindings is idempotent. - - This test verifies that the save_tag_binding method correctly handles - the case when attempting to create a binding that already exists. - The method should skip existing bindings and not create duplicates, - making the operation idempotent. - - Expected behavior: - - Checks for existing bindings - - Skips tags that are already bound - - Does not create duplicate bindings - - Still commits transaction - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing binding (duplicate detected) - existing_binding = factory.create_tag_binding_mock() - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = existing_binding # Binding already exists - - # Prepare binding arguments - args = {"type": "app", "target_id": "app-123", "tag_ids": ["tag-1"]} - - # Act - # Execute the method under test - TagService.save_tag_binding(args) - - # Assert - # Verify no new binding was added (idempotent) - mock_db_session.add.assert_not_called(), "Should not create duplicate binding" - - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): - """ - Test deleting a tag binding. - - This test verifies that the delete_tag_binding method correctly - removes a binding between a tag and a target resource. This - operation should be safe even if the binding doesn't exist. 
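Deleting a binding is written defensively: a missing row is a no-op and, per the second test below, nothing is committed when nothing changed. A sketch:

from models.model import TagBinding
from services.tag_service import TagService


def delete_tag_binding(session, args: dict) -> None:
    TagService.check_target_exists(args["type"], args["target_id"])
    binding = (
        session.query(TagBinding)
        .where(TagBinding.tag_id == args["tag_id"], TagBinding.target_id == args["target_id"])
        .first()
    )
    if binding is None:
        return  # nothing bound: no delete, no commit
    session.delete(binding)
    session.commit()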
- - The method should: - - Validate target exists (via check_target_exists) - - Find the binding by tag_id and target_id - - Delete the binding if it exists - - Commit the transaction - - Expected behavior: - - Validates target exists - - Deletes the binding - - Commits transaction - """ - # Arrange - # Create mock binding to be deleted - binding = factory.create_tag_binding_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = binding - - # Prepare delete arguments - args = {"type": "app", "target_id": "app-123", "tag_id": "tag-1"} - - # Act - # Execute the method under test - TagService.delete_tag_binding(args) - - # Assert - # Verify target existence was checked - mock_check_target.assert_called_once_with("app", "app-123"), "Should validate target exists" - - # Verify binding was deleted - mock_db_session.delete.assert_called_once_with(binding), "Should delete the binding" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): - """ - Test that deleting a non-existent binding is a no-op. - - This test verifies that the delete_tag_binding method correctly - handles the case when attempting to delete a binding that doesn't - exist. The method should not raise an error and should not commit - if there's nothing to delete. - - Expected behavior: - - Validates target exists - - Does not raise error for non-existent binding - - Does not call delete or commit if binding doesn't exist - """ - # Arrange - # Configure mock database session (binding not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Binding doesn't exist - - # Prepare delete arguments - args = {"type": "app", "target_id": "app-123", "tag_id": "tag-1"} - - # Act - # Execute the method under test - TagService.delete_tag_binding(args) - - # Assert - # Verify no delete operation was attempted - mock_db_session.delete.assert_not_called(), "Should not delete if binding doesn't exist" - - # Verify no commit was made (nothing changed) - mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): - """ - Test validating that a dataset target exists. - - This test verifies that the check_target_exists method correctly - validates the existence of a dataset (knowledge base) when the - target type is "knowledge". This validation ensures bindings - are only created for valid resources. 
- - The method should: - - Query Dataset table filtered by tenant and ID - - Raise NotFound if dataset doesn't exist - - Return normally if dataset exists - - Expected behavior: - - No exception raised when dataset exists - - Database query is executed - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Create mock dataset - dataset = factory.create_dataset_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = dataset # Dataset exists - - # Act - # Execute the method under test - TagService.check_target_exists("knowledge", "dataset-123") - - # Assert - # Verify no exception was raised and query was executed - mock_db_session.query.assert_called_once(), "Should query database for dataset" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): - """ - Test validating that an app target exists. - - This test verifies that the check_target_exists method correctly - validates the existence of an application when the target type is - "app". This validation ensures bindings are only created for valid - resources. - - The method should: - - Query App table filtered by tenant and ID - - Raise NotFound if app doesn't exist - - Return normally if app exists - - Expected behavior: - - No exception raised when app exists - - Database query is executed - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Create mock app - app = factory.create_app_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = app # App exists - - # Act - # Execute the method under test - TagService.check_target_exists("app", "app-123") - - # Assert - # Verify no exception was raised and query was executed - mock_db_session.query.assert_called_once(), "Should query database for app" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_raises_not_found_for_missing_dataset( - self, mock_db_session, mock_current_user, factory - ): - """ - Test that missing dataset raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when attempting to validate a dataset - that doesn't exist. This prevents creating bindings for invalid - resources. 
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Dataset not found" - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (dataset not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Dataset doesn't exist - - # Act & Assert - # Verify NotFound is raised for non-existent dataset - with pytest.raises(NotFound, match="Dataset not found"): - TagService.check_target_exists("knowledge", "nonexistent") - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): - """ - Test that missing app raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when attempting to validate an app - that doesn't exist. This prevents creating bindings for invalid - resources. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "App not found" - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (app not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # App doesn't exist - - # Act & Assert - # Verify NotFound is raised for non-existent app - with pytest.raises(NotFound, match="App not found"): - TagService.check_target_exists("app", "nonexistent") - - def test_check_target_exists_raises_not_found_for_invalid_type(self, factory): - """ - Test that invalid binding type raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when an invalid target type is provided. - Only "knowledge" (for datasets) and "app" are valid target types. 
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Invalid binding type" - """ - # Act & Assert - # Verify NotFound is raised for invalid target type - with pytest.raises(NotFound, match="Invalid binding type"): - TagService.check_target_exists("invalid_type", "target-123") diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 2fe6161785..9c23135225 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -16,9 +16,7 @@ from typing import Any from uuid import uuid4 import pytest - -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File +from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, @@ -30,6 +28,7 @@ from graphon.variables.segments import ( ObjectSegment, StringSegment, ) + from services.variable_truncator import ( DummyVariableTruncator, MaxDepthExceededError, diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 239cc83518..a62c9f4555 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -13,10 +13,10 @@ from datetime import datetime from unittest.mock import MagicMock, create_autospec, patch import pytest +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from graphon.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index da606c8329..cd71981bcf 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -15,7 +15,6 @@ from typing import Any, cast from unittest.mock import ANY, MagicMock, patch import pytest - from graphon.entities import WorkflowNodeExecution from graphon.enums import ( BuiltinNodeTypes, @@ -29,6 +28,7 @@ from graphon.model_runtime.entities.model_entities import ModelType from graphon.node_events import NodeRunResult from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from graphon.variables.input_entities import VariableEntityType + from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from models.model import App, AppMode diff --git a/api/tests/unit_tests/services/test_workspace_service.py b/api/tests/unit_tests/services/test_workspace_service.py deleted file mode 100644 index 9bfd7eb2c5..0000000000 --- a/api/tests/unit_tests/services/test_workspace_service.py +++ /dev/null @@ -1,576 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from models.account import Tenant - -# --------------------------------------------------------------------------- -# Constants used throughout the tests -# 
--------------------------------------------------------------------------- - -TENANT_ID = "tenant-abc" -ACCOUNT_ID = "account-xyz" -FILES_BASE_URL = "https://files.example.com" - -DB_PATH = "services.workspace_service.db" -FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features" -TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles" -DIFY_CONFIG_PATH = "services.workspace_service.dify_config" -CURRENT_USER_PATH = "services.workspace_service.current_user" -CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool" - - -# --------------------------------------------------------------------------- -# Helpers / factories -# --------------------------------------------------------------------------- - - -def _make_tenant( - tenant_id: str = TENANT_ID, - name: str = "My Workspace", - plan: str = "sandbox", - status: str = "active", - custom_config: dict | None = None, -) -> Tenant: - """Create a minimal Tenant-like namespace.""" - return cast( - Tenant, - SimpleNamespace( - id=tenant_id, - name=name, - plan=plan, - status=status, - created_at="2024-01-01T00:00:00Z", - custom_config_dict=custom_config or {}, - ), - ) - - -def _make_feature( - can_replace_logo: bool = False, - next_credit_reset_date: str | None = None, - billing_plan: str = "sandbox", -) -> MagicMock: - """Create a feature namespace matching what FeatureService.get_features returns.""" - feature = MagicMock() - feature.can_replace_logo = can_replace_logo - feature.next_credit_reset_date = next_credit_reset_date - feature.billing.subscription.plan = billing_plan - return feature - - -def _make_pool(quota_limit: int, quota_used: int) -> MagicMock: - pool = MagicMock() - pool.quota_limit = quota_limit - pool.quota_used = quota_used - return pool - - -def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace: - return SimpleNamespace(role=role) - - -def _tenant_info(result: object) -> dict[str, Any] | None: - return cast(dict[str, Any] | None, result) - - -# --------------------------------------------------------------------------- -# Shared fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def mock_current_user() -> SimpleNamespace: - """Return a lightweight current_user stand-in.""" - return SimpleNamespace(id=ACCOUNT_ID) - - -@pytest.fixture -def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: - """ - Patch the common external boundaries used by WorkspaceService.get_tenant_info. - - Returns a dict of named mocks so individual tests can customise them. - """ - mocker.patch(CURRENT_USER_PATH, mock_current_user) - - mock_db_session = mocker.patch(f"{DB_PATH}.session") - mock_query_chain = MagicMock() - mock_db_session.query.return_value = mock_query_chain - mock_query_chain.where.return_value = mock_query_chain - mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") - - mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature()) - mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False) - mock_config = mocker.patch(DIFY_CONFIG_PATH) - mock_config.EDITION = "SELF_HOSTED" - mock_config.FILES_URL = FILES_BASE_URL - - return { - "db_session": mock_db_session, - "query_chain": mock_query_chain, - "get_features": mock_feature, - "has_roles": mock_has_roles, - "config": mock_config, - } - - -# --------------------------------------------------------------------------- -# 1. 
None Tenant Handling -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None: - """get_tenant_info should short-circuit and return None for a falsy tenant.""" - from services.workspace_service import WorkspaceService - - # Arrange - tenant = None - - # Act - result = WorkspaceService.get_tenant_info(cast(Tenant, tenant)) - - # Assert - assert result is None - - -def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None: - """get_tenant_info treats any falsy value as absent (e.g. empty string, 0).""" - from services.workspace_service import WorkspaceService - - # Arrange / Act / Assert - assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type] - - -# --------------------------------------------------------------------------- -# 2. Basic Tenant Info — happy path -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_return_base_fields( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """get_tenant_info should always return the six base scalar fields.""" - from services.workspace_service import WorkspaceService - - # Arrange - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["id"] == TENANT_ID - assert result["name"] == "My Workspace" - assert result["plan"] == "sandbox" - assert result["status"] == "active" - assert result["created_at"] == "2024-01-01T00:00:00Z" - assert result["trial_end_reason"] is None - - -def test_get_tenant_info_should_populate_role_from_tenant_account_join( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """The 'role' field should be taken from TenantAccountJoin, not the default.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin") - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["role"] == "admin" - - -def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """ - The service asserts that TenantAccountJoin exists. - Missing join should raise AssertionError. - """ - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["query_chain"].first.return_value = None - tenant = _make_tenant() - - # Act + Assert - with pytest.raises(AssertionError, match="TenantAccountJoin not found"): - WorkspaceService.get_tenant_info(tenant) - - -# --------------------------------------------------------------------------- -# 3. 
Logo Customisation -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config block should appear for OWNER/ADMIN when can_replace_logo is True.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant( - custom_config={ - "replace_webapp_logo": True, - "remove_webapp_brand": True, - } - ) - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" in result - assert result["custom_config"]["remove_webapp_brand"] is True - expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo" - assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url - - -def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """replace_webapp_logo should be None when custom_config_dict does not have the key.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["custom_config"]["replace_webapp_logo"] is None - - -def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config should be absent when can_replace_logo is False.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" not in result - - -def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config block is gated on OWNER or ADMIN role.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = False # regular member - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" not in result - - -def test_get_tenant_info_should_use_files_url_for_logo_url( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """The logo URL should use dify_config.FILES_URL as the base.""" - from services.workspace_service import WorkspaceService - - # Arrange - custom_base = "https://cdn.mycompany.io" - basic_mocks["config"].FILES_URL = custom_base - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant(custom_config={"replace_webapp_logo": True}) - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert 
- assert result is not None - assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) - - -# --------------------------------------------------------------------------- -# 4. Cloud-Edition Credit Features -# --------------------------------------------------------------------------- - -CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX - - -@pytest.fixture -def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: - """Patches for CLOUD edition tests, billing plan = professional by default.""" - mocker.patch(CURRENT_USER_PATH, mock_current_user) - - mock_db_session = mocker.patch(f"{DB_PATH}.session") - mock_query_chain = MagicMock() - mock_db_session.query.return_value = mock_query_chain - mock_query_chain.where.return_value = mock_query_chain - mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") - - mock_feature = mocker.patch( - FEATURE_SERVICE_PATH, - return_value=_make_feature( - can_replace_logo=False, - next_credit_reset_date="2025-02-01", - billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX, - ), - ) - mocker.patch(TENANT_SERVICE_PATH, return_value=False) - mock_config = mocker.patch(DIFY_CONFIG_PATH) - mock_config.EDITION = "CLOUD" - mock_config.FILES_URL = FILES_BASE_URL - - return { - "db_session": mock_db_session, - "query_chain": mock_query_chain, - "get_features": mock_feature, - "config": mock_config, - } - - -def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """next_credit_reset_date should be present in CLOUD edition.""" - from services.workspace_service import WorkspaceService - - # Arrange - mocker.patch( - CREDIT_POOL_SERVICE_PATH, - side_effect=[None, None], # both paid and trial pools absent - ) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["next_credit_reset_date"] == "2025-02-01" - - -def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """trial_credits/trial_credits_used come from the paid pool when conditions are met.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=1000, quota_used=200) - mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 1000 - assert result["trial_credits_used"] == 200 - - -def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """quota_limit == -1 means unlimited; service should still use the paid pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=-1, quota_used=999) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == -1 - assert result["trial_credits_used"] == 999 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """When paid pool is exhausted (used >= limit), switch to trial 
pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full - trial_pool = _make_pool(quota_limit=100, quota_used=10) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 100 - assert result["trial_credits_used"] == 10 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """When paid_pool is None, fall back to trial pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - trial_pool = _make_pool(quota_limit=50, quota_used=5) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 50 - assert result["trial_credits_used"] == 5 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """ - When the subscription plan IS SANDBOX, the paid pool branch is skipped - entirely and we fall back to the trial pool. - """ - from enums.cloud_plan import CloudPlan - from services.workspace_service import WorkspaceService - - # Arrange — override billing plan to SANDBOX - cloud_mocks["get_features"].return_value = _make_feature( - next_credit_reset_date="2025-02-01", - billing_plan=CloudPlan.SANDBOX, - ) - paid_pool = _make_pool(quota_limit=1000, quota_used=0) - trial_pool = _make_pool(quota_limit=200, quota_used=20) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 200 - assert result["trial_credits_used"] == 20 - - -def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """When both paid and trial pools are absent, trial_credits should not be set.""" - from services.workspace_service import WorkspaceService - - # Arrange - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "trial_credits" not in result - assert "trial_credits_used" not in result - - -# --------------------------------------------------------------------------- -# 5. 
Self-hosted / Non-Cloud Edition -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" - from services.workspace_service import WorkspaceService - - # Arrange (basic_mocks already sets EDITION = "SELF_HOSTED") - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "next_credit_reset_date" not in result - assert "trial_credits" not in result - assert "trial_credits_used" not in result - - -# --------------------------------------------------------------------------- -# 6. DB query integrity -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """ - The DB query for TenantAccountJoin must be scoped to the correct - tenant_id and current_user.id. - """ - from services.workspace_service import WorkspaceService - - # Arrange - tenant = _make_tenant(tenant_id="my-special-tenant") - mock_current_user = mocker.patch(CURRENT_USER_PATH) - mock_current_user.id = "special-user-id" - - # Act - WorkspaceService.get_tenant_info(tenant) - - # Assert — db.session.query was invoked (at least once) - basic_mocks["db_session"].query.assert_called() diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index 2db83576b0..8525672da8 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -4,13 +4,12 @@ import json from unittest.mock import Mock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import ObjectSegment, StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from core.workflow.file_reference import build_file_reference -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables.segments import ObjectSegment, StringSegment -from graphon.variables.types import SegmentType from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 6200c9f859..e7e72793a3 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -4,6 +4,10 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from sqlalchemy.orm import Session @@ -13,11 +17,6 @@ from core.workflow.variable_prefixes import ( ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from graphon.enums import 
BuiltinNodeTypes -from graphon.file.enums import FileTransferMethod, FileType -from graphon.file.models import File -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index ce66b78b64..077a7c27a2 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -8,13 +8,13 @@ from datetime import UTC, datetime from threading import Event import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index d7192994b2..98d057e41f 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -3,6 +3,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from sqlalchemy.orm import sessionmaker from core.workflow.human_input_compat import ( @@ -12,9 +15,6 @@ from core.workflow.human_input_compat import ( ExternalRecipient, MemberRecipient, ) -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py deleted file mode 100644 index 6b04a1bc09..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ /dev/null @@ -1,415 +0,0 @@ -from contextlib import nullcontext -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import FormInputType -from models.model import App -from models.workflow import Workflow -from services import workflow_service as 
workflow_service_module -from services.workflow_service import WorkflowService - - -class TestWorkflowService: - @pytest.fixture - def workflow_service(self): - mock_session_maker = MagicMock() - return WorkflowService(mock_session_maker) - - @pytest.fixture - def mock_app(self): - app = MagicMock(spec=App) - app.id = "app-id-1" - app.workflow_id = "workflow-id-1" - app.tenant_id = "tenant-id-1" - return app - - @pytest.fixture - def mock_workflows(self): - workflows = [] - for i in range(5): - workflow = MagicMock(spec=Workflow) - workflow.id = f"workflow-id-{i}" - workflow.app_id = "app-id-1" - workflow.created_at = f"2023-01-0{5 - i}" # Descending date order - workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2" - workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else "" - workflows.append(workflow) - return workflows - - @pytest.fixture - def dummy_session_cls(self): - class DummySession: - def __init__(self, *args, **kwargs): - self.commit = MagicMock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return nullcontext() - - return DummySession - - def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app): - mock_app.workflow_id = None - mock_session = MagicMock() - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None - ) - - assert workflows == [] - assert has_more is False - mock_session.scalars.assert_not_called() - - def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - mock_scalar_result.all.return_value = mock_workflows[:3] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None - ) - - assert workflows == mock_workflows[:3] - assert has_more is False - mock_session.scalars.assert_called_once() - - def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Return 4 items when limit is 3, which should indicate has_more=True - mock_scalar_result.all.return_value = mock_workflows[:4] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None - ) - - # Should return only the first 3 items - assert len(workflows) == 3 - assert workflows == mock_workflows[:3] - assert has_more is True - - # Test page 2 - mock_scalar_result.all.return_value = mock_workflows[3:] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None - ) - - assert len(workflows) == 2 - assert has_more is False - - def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Filter workflows for user-id-1 - filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"] - mock_scalar_result.all.return_value = filtered_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - 
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1" - ) - - assert workflows == filtered_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that the select contains a user filter clause - args = mock_session.scalars.call_args[0][0] - assert "created_by" in str(args) - - def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Filter workflows that have a marked_name - named_workflows = [w for w in mock_workflows if w.marked_name] - mock_scalar_result.all.return_value = named_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True - ) - - assert workflows == named_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that the select contains a named_only filter clause - args = mock_session.scalars.call_args[0][0] - assert "marked_name !=" in str(args) - - def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Combined filter: user-id-1 and has marked_name - filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name] - mock_scalar_result.all.return_value = filtered_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True - ) - - assert workflows == filtered_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that both filters are applied - args = mock_session.scalars.call_args[0][0] - assert "created_by" in str(args) - assert "marked_name !=" in str(args) - - def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - mock_scalar_result.all.return_value = [] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None - ) - - assert workflows == [] - assert has_more is False - mock_session.scalars.assert_called_once() - - def test_submit_human_input_form_preview_uses_rendered_content( - self, - workflow_service: WorkflowService, - monkeypatch: pytest.MonkeyPatch, - dummy_session_cls, - ) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="
{{#$output.name#}}", -            inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], -            user_actions=[UserAction(id="approve", title="Approve")], -        ) -        node = MagicMock() -        node.node_data = node_data -        node.render_form_content_before_submission.return_value = "preview" -        node.render_form_content_with_outputs.return_value = "rendered
" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - node_config = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} - ) - workflow.get_node_config_by_id.return_value = node_config - workflow.get_enclosing_node_type_and_id.return_value = None - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - saved_outputs: dict[str, object] = {} - - class DummySaver: - def __init__(self, *args, **kwargs): - pass - - def save(self, outputs, process_data): - saved_outputs.update(outputs) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - - result = service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={"name": "Ada", "extra": "ignored"}, - inputs={"#node-0.result#": "LLM output"}, - action="approve", - ) - - service._build_human_input_variable_pool.assert_called_once_with( - app_model=app_model, - workflow=workflow, - node_config=node_config, - manual_inputs={"#node-0.result#": "LLM output"}, - user_id="account-1", - ) - - node.render_form_content_with_outputs.assert_called_once() - called_args = node.render_form_content_with_outputs.call_args.args - assert called_args[0] == "
preview
" - assert called_args[2] == node_data.outputs_field_names() - rendered_outputs = called_args[1] - assert rendered_outputs["name"] == "Ada" - assert rendered_outputs["extra"] == "ignored" - assert "extra" in saved_outputs - assert "extra" in result - assert saved_outputs["name"] == "Ada" - assert result["name"] == "Ada" - assert result["__action_id"] == "approve" - assert "__rendered_content" in result - - def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="
{{#$output.name#}}", -            inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], -            user_actions=[UserAction(id="approve", title="Approve")], -        ) -        node = MagicMock() -        node.node_data = node_data -        node._render_form_content_before_submission.return_value = "preview" -        node._render_form_content_with_outputs.return_value = "rendered
" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} - ) - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - with pytest.raises(ValueError) as exc_info: - service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={}, - inputs={}, - action="approve", - ) - - assert "Missing required inputs" in str(exc_info.value) - - def test_run_draft_workflow_node_successful_behavior( - self, workflow_service, mock_app, monkeypatch, dummy_session_cls - ): - """Behavior: When a basic workflow node runs, it correctly sets up context, - executes the node, and saves outputs.""" - service = workflow_service - account = SimpleNamespace(id="account-1") - mock_workflow = MagicMock() - mock_workflow.id = "wf-1" - mock_workflow.tenant_id = "tenant-1" - mock_workflow.environment_variables = [] - mock_workflow.conversation_variables = [] - - # Mock node config - mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} - ) - mock_workflow.get_enclosing_node_type_and_id.return_value = None - - # Mock class methods - monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) - monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) - - # Mock workflow entry execution - mock_node_exec = MagicMock() - mock_node_exec.id = "exec-1" - mock_node_exec.process_data = {} - mock_run = MagicMock() - monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run) - - # Mock execution handling - service._handle_single_step_result = MagicMock(return_value=mock_node_exec) - - # Mock repository - mock_repo = MagicMock() - mock_repo.get_execution_by_id.return_value = mock_node_exec - mock_repo_factory = MagicMock(return_value=mock_repo) - monkeypatch.setattr( - workflow_service_module.DifyCoreRepositoryFactory, - "create_workflow_node_execution_repository", - mock_repo_factory, - ) - service._node_execution_service_repo = mock_repo - - # Set up node execution service repo mock to return our exec node - mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"} - mock_node_exec.node_id = "node-1" - mock_node_exec.node_type = "llm" - - # Mock draft variable saver - mock_saver = MagicMock() - monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver)) - - # Mock DB - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - - # Act - result = service.run_draft_workflow_node( - app_model=mock_app, - draft_workflow=mock_workflow, - node_id="node-1", - user_inputs={"input_val": "test"}, - account=account, - ) - - # Assert - assert result == mock_node_exec - service._handle_single_step_result.assert_called_once() - mock_repo.save.assert_called_once_with(mock_node_exec) - mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"}) - 
- def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls): - """Behavior: If retrieving the saved execution fails, an appropriate error bubble matches expectations.""" - service = workflow_service - account = SimpleNamespace(id="account-1") - mock_workflow = MagicMock() - mock_workflow.tenant_id = "tenant-1" - mock_workflow.environment_variables = [] - mock_workflow.conversation_variables = [] - mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} - ) - mock_workflow.get_enclosing_node_type_and_id.return_value = None - - monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) - monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) - monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock()) - - mock_node_exec = MagicMock() - mock_node_exec.id = "exec-invalid" - service._handle_single_step_result = MagicMock(return_value=mock_node_exec) - - mock_repo = MagicMock() - mock_repo_factory = MagicMock(return_value=mock_repo) - monkeypatch.setattr( - workflow_service_module.DifyCoreRepositoryFactory, - "create_workflow_node_execution_repository", - mock_repo_factory, - ) - service._node_execution_service_repo = mock_repo - - # Simulate failure to retrieve the saved execution - mock_repo.get_execution_by_id.return_value = None - - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - - # Act & Assert - with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"): - service.run_draft_workflow_node( - app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account - ) diff --git a/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py new file mode 100644 index 0000000000..b48c69a146 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py @@ -0,0 +1,69 @@ +"""Unit tests for enterprise telemetry Celery task.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + +@pytest.fixture +def sample_envelope_json(): + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123"}, + ) + return envelope.model_dump_json() + + +def test_process_enterprise_telemetry_success(sample_envelope_json): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + mock_handler.handle.assert_called_once() + call_args = mock_handler.handle.call_args[0][0] + assert isinstance(call_args, TelemetryEnvelope) + assert call_args.case == TelemetryCase.APP_CREATED + assert call_args.tenant_id == "test-tenant" + assert call_args.event_id == "test-event-123" + + +def test_process_enterprise_telemetry_invalid_json(caplog): + invalid_json = "not valid json" + + process_enterprise_telemetry(invalid_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + 
+ +def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler.handle.side_effect = Exception("Handler error") + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_validation_error(caplog): + invalid_envelope = json.dumps( + { + "case": "INVALID_CASE", + "tenant_id": "test-tenant", + "event_id": "test-event", + "payload": {}, + } + ) + + process_enterprise_telemetry(invalid_envelope) + + assert "Failed to process enterprise telemetry envelope" in caplog.text diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index 591da56f49..7119217e94 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -5,8 +5,8 @@ from types import SimpleNamespace from typing import Any import pytest - from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus + from tasks import human_input_timeout_tasks as task_module diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index f31bf80046..68359ba078 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -3,6 +3,7 @@ from decimal import Decimal from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.mcp.types import ( AudioContent, @@ -17,7 +18,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool -from graphon.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index c166a946d9..ffa6833524 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -2,9 +2,6 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -21,6 +18,9 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output + def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: """Create a mock LLMUsage with all required fields""" diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py index a29df0bb6b..d33ac2c710 100644 --- 
a/api/tests/workflow_test_utils.py +++ b/api/tests/workflow_test_utils.py @@ -1,12 +1,13 @@ from collections.abc import Mapping from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities import GraphInitParams from graphon.runtime import VariablePool from graphon.variables.variables import Variable +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool + def build_test_run_context( *, diff --git a/api/uv.lock b/api/uv.lock index fb08594fb3..c4cf31e3f5 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -204,7 +204,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/9a/7d/b22cb9a0d4f396ee0 [[package]] name = "alibabacloud-tea-openapi" -version = "0.4.3" +version = "0.4.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-credentials" }, @@ -213,9 +213,9 @@ dependencies = [ { name = "cryptography" }, { name = "darabonba-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/91/4f/b5288eea8f4d4b032c9a8f2cd1d926d5017977d10b874956f31e5343f299/alibabacloud_tea_openapi-0.4.3.tar.gz", hash = "sha256:12aef036ed993637b6f141abbd1de9d6199d5516f4a901588bb65d6a3768d41b", size = 21864, upload-time = "2026-01-15T07:55:16.744Z" } +sdist = { url = "https://files.pythonhosted.org/packages/30/93/138bcdc8fc596add73e37cf2073798f285284d1240bda9ee02f9384fc6be/alibabacloud_tea_openapi-0.4.4.tar.gz", hash = "sha256:1b0917bc03cd49417da64945e92731716d53e2eb8707b235f54e45b7473221ce", size = 21960, upload-time = "2026-03-26T10:16:16.792Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/37/48ee5468ecad19c6d44cf3b9629d77078e836ee3ec760f0366247f307b7c/alibabacloud_tea_openapi-0.4.3-py3-none-any.whl", hash = "sha256:d0b3a373b760ef6278b25fc128c73284301e07888977bf97519e7636d47bdf0a", size = 26159, upload-time = "2026-01-15T07:55:15.72Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5a/6bfc4506438c1809c486f66217ad11eab78157192b3d5707b4e2f4212f6c/alibabacloud_tea_openapi-0.4.4-py3-none-any.whl", hash = "sha256:cea6bc1fe35b0319a8752cb99eb0ecb0dab7ca1a71b99c12970ba0867410995f", size = 26236, upload-time = "2026-03-26T10:16:15.861Z" }, ] [[package]] @@ -494,28 +494,29 @@ wheels = [ [[package]] name = "basedpyright" -version = "1.38.3" +version = "1.38.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodejs-wheel-binaries" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0f/58/7abba2c743571a42b2548f07aee556ebc1e4d0bc2b277aeba1ee6c83b0af/basedpyright-1.38.3.tar.gz", hash = "sha256:9725419786afbfad8a9539527f162da02d462afad440b0412fdb3f3cdf179b90", size = 25277430, upload-time = "2026-03-17T13:10:41.526Z" } +sdist = { url = "https://files.pythonhosted.org/packages/08/b4/26cb812eaf8ab56909c792c005fe1690706aef6f21d61107639e46e9c54c/basedpyright-1.38.4.tar.gz", hash = "sha256:8e7d4f37ffb6106621e06b9355025009cdf5b48f71c592432dd2dd304bf55e70", size = 25354730, upload-time = "2026-03-25T13:50:44.353Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e3/3ebb5c23bd3abb5fc2053b8a06a889aa5c1cf8cff738c78cb6c1957e90cd/basedpyright-1.38.3-py3-none-any.whl", hash = 
"sha256:1f15c2e489c67d6c5e896c24b6a63251195c04223a55e4568b8f8e8ed49ca830", size = 12313363, upload-time = "2026-03-17T13:10:47.344Z" }, + { url = "https://files.pythonhosted.org/packages/62/0b/3f95fd47def42479e61077523d3752086d5c12009192a7f1c9fd5507e687/basedpyright-1.38.4-py3-none-any.whl", hash = "sha256:90aa067cf3e8a3c17ad5836a72b9e1f046bc72a4ad57d928473d9368c9cd07a2", size = 12352258, upload-time = "2026-03-25T13:50:41.059Z" }, ] [[package]] name = "bce-python-sdk" -version = "0.9.64" +version = "0.9.67" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "crc32c" }, { name = "future" }, { name = "pycryptodome" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/61/33/047e9c1a6c97e0cd4d93a6490abd8fbc2ccd13569462fc0228699edc08bc/bce_python_sdk-0.9.64.tar.gz", hash = "sha256:901bf787c26ad35855a80d65e58d7584c8541f7f0f2af20847830e572e5b622e", size = 287125, upload-time = "2026-03-17T11:24:29.345Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/b9/5140cc02832fe3a7394c52949796d43f8c1f635aa016100f857f504e0348/bce_python_sdk-0.9.67.tar.gz", hash = "sha256:2c673d757c5c8952f1be6611da4ab77a63ecabaa3ff22b11531f46845ac99e58", size = 295251, upload-time = "2026-03-24T14:10:07.086Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/48/7f/dd289582f37ab4effea47b2a8503880db4781ca0fc8e0a8ed5ff493359e5/bce_python_sdk-0.9.64-py3-none-any.whl", hash = "sha256:eaad97e4f0e7d613ae978da3cdc5294e9f724ffca2735f79820037fa1317cd6d", size = 402233, upload-time = "2026-03-17T11:24:24.673Z" }, + { url = "https://files.pythonhosted.org/packages/d4/a9/a58a63e2756e5d01901595af58c673f68de7621f28d71007479e00f45a6c/bce_python_sdk-0.9.67-py3-none-any.whl", hash = "sha256:3054879d098a92ceeb4b9ac1e64d2c658120a5a10e8e630f22410564b2170bf0", size = 410854, upload-time = "2026-03-24T14:09:54.29Z" }, ] [[package]] @@ -630,30 +631,30 @@ wheels = [ [[package]] name = "boto3" -version = "1.42.73" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e4/8b/d00575be514744ca4839e7d85bf4a8a3c7b6b4574433291e58d14c68ae09/boto3-1.42.73.tar.gz", hash = "sha256:d37b58d6cd452ca808dd6823ae19ca65b6244096c5125ef9052988b337298bae", size = 112775, upload-time = "2026-03-20T19:39:52.814Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/2b/ebdad075934cf6bb78bf81fe31d83339bcd804ad6c856f7341376cbc88b6/boto3-1.42.78.tar.gz", hash = "sha256:cef2ebdb9be5c0e96822f8d3941ac4b816c90a5737a7ffb901d664c808964b63", size = 112789, upload-time = "2026-03-27T19:28:07.58Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/05/1fcf03d90abaa3d0b42a6bfd10231dd709493ecbacf794aa2eea5eae6841/boto3-1.42.73-py3-none-any.whl", hash = "sha256:1f81b79b873f130eeab14bb556417a7c66d38f3396b7f2fe3b958b3f9094f455", size = 140556, upload-time = "2026-03-20T19:39:50.298Z" }, + { url = "https://files.pythonhosted.org/packages/57/bb/1f6dade1f1e86858bef7bd332bc8106c445f2dbabec7b32ab5d7d118c9b6/boto3-1.42.78-py3-none-any.whl", hash = "sha256:480a34a077484a5ca60124dfd150ba3ea6517fc89963a679e45b30c6db614d26", size = 140556, upload-time = "2026-03-27T19:28:06.125Z" }, ] [[package]] name = "boto3-stubs" -version = "1.42.73" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore-stubs" }, { name = "types-s3transfer" }, { name = "typing-extensions", marker = 
"python_full_version < '3.12'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b9/c3/fcc47102c63278af25ad57c93d97dc393f4dbc54c0117a29c78f2b96ec1e/boto3_stubs-1.42.73.tar.gz", hash = "sha256:36f625769b5505c4bc627f16244b98de9e10dae3ac36f1aa0f0ebe2f201dc138", size = 101373, upload-time = "2026-03-20T19:59:51.463Z" } +sdist = { url = "https://files.pythonhosted.org/packages/03/16/4bdb3c1f69bf7b97dd8b22fe5b007e9da67ba3f00ed10e47146f5fd9d0ff/boto3_stubs-1.42.78.tar.gz", hash = "sha256:423335b8ce9a935e404054978589cdb98d9fa1d4bd46073d6821bf1c3fad8ca7", size = 101602, upload-time = "2026-03-27T19:35:51.149Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4b/57/d570ba61a2a0c7fe0c8667e41269a0480293cb53e1786d6661a2bd827fc5/boto3_stubs-1.42.73-py3-none-any.whl", hash = "sha256:bd658429069d8215247fc3abc003220cd875c24ab6eda7b3405090408afaacdf", size = 70009, upload-time = "2026-03-20T19:59:43.786Z" }, + { url = "https://files.pythonhosted.org/packages/22/d5/bdedd4951c795899ac5a1f0b88d81b9e2c6333cb87457f2edd11ef3b7b7b/boto3_stubs-1.42.78-py3-none-any.whl", hash = "sha256:6ed07e734174751da8d01031d9ede8d81a88e4338d9e6b00ce7a6bc870075372", size = 70161, upload-time = "2026-03-27T19:35:46.336Z" }, ] [package.optional-dependencies] @@ -663,16 +664,16 @@ bedrock-runtime = [ [[package]] name = "botocore" -version = "1.42.73" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/28/23/0c88ca116ef63b1ae77c901cd5d2095d22a8dbde9e80df74545db4a061b4/botocore-1.42.73.tar.gz", hash = "sha256:575858641e4949aaf2af1ced145b8524529edf006d075877af6b82ff96ad854c", size = 15008008, upload-time = "2026-03-20T19:39:40.082Z" } +sdist = { url = "https://files.pythonhosted.org/packages/67/8e/cdb34c8ca71216d214e049ada2148ee08bcda12b1ac72af3a720dea300ff/botocore-1.42.78.tar.gz", hash = "sha256:61cbd49728e23f68cfd945406ab40044d49abed143362f7ffa4a4f4bd4311791", size = 15023592, upload-time = "2026-03-27T19:27:57.122Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8e/65/971f3d55015f4d133a6ff3ad74cd39f4b8dd8f53f7775a3c2ad378ea5145/botocore-1.42.73-py3-none-any.whl", hash = "sha256:7b62e2a12f7a1b08eb7360eecd23bb16fe3b7ab7f5617cf91b25476c6f86a0fe", size = 14681861, upload-time = "2026-03-20T19:39:35.341Z" }, + { url = "https://files.pythonhosted.org/packages/54/72/94bba1a375d45c685b00e051b56142359547837086a83861d76f6aec26f4/botocore-1.42.78-py3-none-any.whl", hash = "sha256:038ab63c7f898e8b5db58cb6a45e4da56c31dd984e7e995839a3540c735564ea", size = 14701729, upload-time = "2026-03-27T19:27:54.05Z" }, ] [[package]] @@ -1062,7 +1063,7 @@ wheels = [ [[package]] name = "clickhouse-connect" -version = "0.14.1" +version = "0.15.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -1071,24 +1072,24 @@ dependencies = [ { name = "urllib3" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f5/0e/96958db88b6ce6e9d96dc7a836f12c7644934b3a436b04843f19eb8da2db/clickhouse_connect-0.14.1.tar.gz", hash = "sha256:dc107ae9ab7b86409049ae8abe21817543284b438291796d3dd639ad5496a1ab", size = 120093, upload-time = "2026-03-12T15:51:03.606Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/59/c0b0a2c2e4c204e5baeca4917a95cc95add651da3cec86ec464a8e54cfa0/clickhouse_connect-0.15.0.tar.gz", hash = 
"sha256:529fcf072df335d18ae16339d99389190f4bd543067dcdc174541c7a9c622ef5", size = 126344, upload-time = "2026-03-26T18:34:52.316Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/b0/04bc82ca70d4dcc35987c83e4ef04f6dec3c29d3cce4cda3523ebf4498dc/clickhouse_connect-0.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f2b1d1acb8f64c3cd9d922d9e8c0b6328238c4a38e084598c86cc95a0edbd8bd", size = 278797, upload-time = "2026-03-12T15:49:34.728Z" }, - { url = "https://files.pythonhosted.org/packages/97/03/f8434ed43946dcab2d8b4ccf8e90b1c6d69abea0fa8b8aaddb1dc9931657/clickhouse_connect-0.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:573f3e5a6b49135b711c086050f46510d4738cc09e5a354cc18ef26f8de5cd98", size = 271849, upload-time = "2026-03-12T15:49:35.881Z" }, - { url = "https://files.pythonhosted.org/packages/a0/db/b3665f4d855c780be8d00638d874fc0d62613d1f1c06ffcad7c11a333f06/clickhouse_connect-0.14.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:86b28932faab182a312779e5c3cf341abe19d31028a399bda9d8b06b3b9adab4", size = 1090975, upload-time = "2026-03-12T15:49:37.064Z" }, - { url = "https://files.pythonhosted.org/packages/ea/a2/7ba2d9669c5771734573397b034169653cdf3348dc4cc66bd66d8ab18910/clickhouse_connect-0.14.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfc9650906ff96452c2b5676a7e68e8a77a5642504596f8482e0f3c0ccdffbf1", size = 1095899, upload-time = "2026-03-12T15:49:38.36Z" }, - { url = "https://files.pythonhosted.org/packages/e2/f4/0394af37b491ca832610f2ca7a129e85d8d857d40c94a42f2c2e6d3d9481/clickhouse_connect-0.14.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b379749a962599f9d6ec81e773a3b907ac58b001f4a977e4ac397f6a76fedff2", size = 1077567, upload-time = "2026-03-12T15:49:40.027Z" }, - { url = "https://files.pythonhosted.org/packages/9a/b8/9279a88afac94c262b55cc75aadc6a3e83f7fa1641e618f9060d9d38415f/clickhouse_connect-0.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:43ccb5debd13d41b97af81940c0cac01e92d39f17131d984591bedee13439a5d", size = 1100264, upload-time = "2026-03-12T15:49:41.414Z" }, - { url = "https://files.pythonhosted.org/packages/19/36/20e19ab392c211b83c967e275eb46f663853e0b8ce4da89056fda8a35fc6/clickhouse_connect-0.14.1-cp311-cp311-win32.whl", hash = "sha256:13cbe46c04be8e49da4f6aed698f2570a5295d15f498dd5511b4f761d1ef0edc", size = 250488, upload-time = "2026-03-12T15:49:42.649Z" }, - { url = "https://files.pythonhosted.org/packages/9d/3b/74a07e692a21cad4692e72595cdefbd709bd74a9f778c7334d57a98ee548/clickhouse_connect-0.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:7038cf547c542a17a465e062cd837659f46f99c991efcb010a9ea08ce70960ab", size = 268730, upload-time = "2026-03-12T15:49:44.225Z" }, - { url = "https://files.pythonhosted.org/packages/58/9e/d84a14241967b3aa1e657bbbee83e2eee02d3d6df1ebe8edd4ed72cd8643/clickhouse_connect-0.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:97665169090889a8bc4dbae4a5fc758b91a23e49a8f8ddc1ae993f18f6d71e02", size = 280679, upload-time = "2026-03-12T15:49:45.497Z" }, - { url = "https://files.pythonhosted.org/packages/d8/29/80835a980be6298a7a2ae42d5a14aab0c9c066ecafe1763bc1958a6f6f0f/clickhouse_connect-0.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3ee6b513ca7d83e0f7b46d87bc2e48260316431cb466680e3540400379bcd1db", size = 271570, upload-time = "2026-03-12T15:49:46.721Z" }, - { url = 
"https://files.pythonhosted.org/packages/8b/bf/25c17cb91d72143742d2b060c6954e8000a7753c1fd21f7bf8b49ef2bd89/clickhouse_connect-0.14.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2a0e8a3f46aba99f1c574927d196e12f1ee689e31c41bf0caec86ad3e181abf3", size = 1115637, upload-time = "2026-03-12T15:49:47.921Z" }, - { url = "https://files.pythonhosted.org/packages/2d/5f/5d5df3585d98889aedc55c9eeb2ea90dba27ec4329eee392101619daf0c0/clickhouse_connect-0.14.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:25698cddcdd6c2e4ea12dc5c56d6035d77fc99c5d75e96a54123826c36fdd8ae", size = 1131995, upload-time = "2026-03-12T15:49:49.791Z" }, - { url = "https://files.pythonhosted.org/packages/ad/50/acc9f4c6a1d712f2ed11626f8451eff222e841cf0809655362f0e90454b6/clickhouse_connect-0.14.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:29ab49e5cac44b830b58de73d17a7d895f6c362bf67a50134ff405b428774f44", size = 1095380, upload-time = "2026-03-12T15:49:51.388Z" }, - { url = "https://files.pythonhosted.org/packages/08/18/1ef01beee93d243ec9d9c37f0ce62b3083478a5dd7f59cc13279600cd3a5/clickhouse_connect-0.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3cbf7d7a134692bacd68dd5f8661e87f5db94af60db9f3a74bd732596794910a", size = 1127217, upload-time = "2026-03-12T15:49:53.016Z" }, - { url = "https://files.pythonhosted.org/packages/18/e2/b4daee8287dc49eb9918c77b1e57f5644e47008f719b77281bf5fca63f6e/clickhouse_connect-0.14.1-cp312-cp312-win32.whl", hash = "sha256:6f295b66f3e2ed931dd0d3bb80e00ee94c6f4a584b2dc6d998872b2e0ceaa706", size = 250775, upload-time = "2026-03-12T15:49:54.639Z" }, - { url = "https://files.pythonhosted.org/packages/01/c7/7b55d346952fcd8f0f491faca4449f607a04764fd23cada846dc93facb9e/clickhouse_connect-0.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:c6bb2cce37041c90f8a3b1b380665acbaf252f125e401c13ce8f8df105378f69", size = 269353, upload-time = "2026-03-12T15:49:55.854Z" }, + { url = "https://files.pythonhosted.org/packages/83/b0/bf4a169a1b4e5e19f5e884596937ce13855146a3f4b3225228a87701fd18/clickhouse_connect-0.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f0928fdfb408d314c0e5151caf30b1c3bd56c2812ffdbc8d262fb60c0e7ab28", size = 284805, upload-time = "2026-03-26T18:33:18.659Z" }, + { url = "https://files.pythonhosted.org/packages/ec/d5/63dd572db91bd5e1231d7b7dc63591c52ffbbf653a57f9b8449681815976/clickhouse_connect-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6486b02825ac87f57811710e5a9a2da8531bb3c88bcb154fd5c7378742a33d66", size = 277846, upload-time = "2026-03-26T18:33:20.171Z" }, + { url = "https://files.pythonhosted.org/packages/e4/d6/192130a807de130945cc451e17c89ac6183625b8028026e5a4a7fc46fa59/clickhouse_connect-0.15.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f2df9c2fd97b40c6493232e0cbf516d8ba268165c6161851ef15f4f1fd0456e", size = 1096969, upload-time = "2026-03-26T18:33:21.728Z" }, + { url = "https://files.pythonhosted.org/packages/32/46/f2895cc4240ef45a2a274d4323f6858c0860034efe6c9a1c7168f1d8cecd/clickhouse_connect-0.15.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a5a349d19c63abb49c884afe0a0387823045831f005451e85c09c032f953f1c1", size = 1101890, upload-time = "2026-03-26T18:33:23.038Z" }, + { url = 
"https://files.pythonhosted.org/packages/e8/69/dcecbca254b45525ad3fd8294441ac9cf8a8a8bd1fa8fd6b93e241b377a3/clickhouse_connect-0.15.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4d80205cbdbface6d2f35fbd65a6f85caf2b59ec65f2e9dd190f11e335fe7316", size = 1083561, upload-time = "2026-03-26T18:33:24.64Z" }, + { url = "https://files.pythonhosted.org/packages/69/10/21f0cb98453d9710aaeb92f9a9e156e909c1ac72e57210a48b0f615916a7/clickhouse_connect-0.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c3c84dfebf49ec7a2cd9ac31c46986f7a81b43ea781d23ef7d607907fcc6de5d", size = 1106257, upload-time = "2026-03-26T18:33:26.257Z" }, + { url = "https://files.pythonhosted.org/packages/70/91/ae0f5c8df5dc650f1ab327d4b40cde7e18bf9e8b3507764dce320c328092/clickhouse_connect-0.15.0-cp311-cp311-win32.whl", hash = "sha256:d2bbdccf9cd838b990576d3f7d1e6a0ab5c3a5c8eb830394258b7b225531fe74", size = 256591, upload-time = "2026-03-26T18:33:27.869Z" }, + { url = "https://files.pythonhosted.org/packages/e6/7f/85673ff522554ef76e17b5d267816c199a731fde836ef957b0960655f251/clickhouse_connect-0.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:1c4223d557bc0a3919cb7ce0d749d9091123b6e61341e028ffc09b7f9c847ac2", size = 274778, upload-time = "2026-03-26T18:33:29.02Z" }, + { url = "https://files.pythonhosted.org/packages/f5/be/86e149c60822caed29e4435acac4fc73e20fddfb0b56ea6452bc7a08ab10/clickhouse_connect-0.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d51f49694e9007564bfd8dac51a1f9e60b94d6c93a07eb4027113a2e62bbb384", size = 286680, upload-time = "2026-03-26T18:33:30.219Z" }, + { url = "https://files.pythonhosted.org/packages/aa/65/c38cc5028afa2ccd9e8ff65611434063c0c5c1b6edadc507dbbc80a09bfd/clickhouse_connect-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6a48fbad9ebc2b6d1cd01d1f9b5d6740081f1c84f1aacc9f91651be949f6b6ed", size = 277579, upload-time = "2026-03-26T18:33:31.474Z" }, + { url = "https://files.pythonhosted.org/packages/0a/ef/c8b2ef597fefd04e8b7c017c991552162cb89b7cb73bfdd6225b1c79e2fe/clickhouse_connect-0.15.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36e1ae470b94cc56d270461c8626c8fd4dac16e6c1ffa8477f21c012462e22cf", size = 1121630, upload-time = "2026-03-26T18:33:32.983Z" }, + { url = "https://files.pythonhosted.org/packages/de/f7/1b71819e825d44582c014a489618170b03ccdac3c9b710dfd56445f1c017/clickhouse_connect-0.15.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fa97f0ae8eb069a451d8577342dffeef5dc308a0eac7dba1809008c761e720c7", size = 1137988, upload-time = "2026-03-26T18:33:34.585Z" }, + { url = "https://files.pythonhosted.org/packages/7f/1f/41002b8d5ff146dc2835dc6b6f690bc361bd9a94b6195872abcb922f3788/clickhouse_connect-0.15.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b5b3baf70009174a4df9c8356c96d03e1c2dbf0d8b29f1b3270a641a59399b61", size = 1101376, upload-time = "2026-03-26T18:33:36.258Z" }, + { url = "https://files.pythonhosted.org/packages/2c/8a/bd090dab73fc9c47efcaaeb152a77610b9d233cd88ea73cf4535f9bac2a6/clickhouse_connect-0.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:af3fba93fd2efa8f856f3a88a6a710e06005fa48b6b6b0f116d462a4021957e2", size = 1133211, upload-time = "2026-03-26T18:33:38.003Z" }, + { url = "https://files.pythonhosted.org/packages/f1/8d/cf4eee7225bdee85a9b8a88c5bfff42ce48f37ee9277930ac8bc76f47126/clickhouse_connect-0.15.0-cp312-cp312-win32.whl", hash = "sha256:86ca76f8acaf7f3f6530e3e4139e174d54c4674910c69f4277d1b9cdf7c1cc98", size = 
256767, upload-time = "2026-03-26T18:33:39.55Z" }, + { url = "https://files.pythonhosted.org/packages/26/6e/f5a2cb1e4624dfd77c1e226239360a9e3690db8056a0027bda2ab87d0085/clickhouse_connect-0.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a471d9a9cf06f0a4e90784547b6a2acb066b0d8642dfea9866960c4bdde6959", size = 275404, upload-time = "2026-03-26T18:33:40.885Z" }, ] [[package]] @@ -1308,43 +1309,47 @@ wheels = [ [[package]] name = "cryptography" -version = "44.0.3" +version = "46.0.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/53/d6/1411ab4d6108ab167d06254c5be517681f1e331f90edf1379895bcb87020/cryptography-44.0.3.tar.gz", hash = "sha256:fe19d8bc5536a91a24a8133328880a41831b6c5df54599a8417b62fe015d3053", size = 711096, upload-time = "2025-05-02T19:36:04.667Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a4/ba/04b1bd4218cbc58dc90ce967106d51582371b898690f3ae0402876cc4f34/cryptography-46.0.6.tar.gz", hash = "sha256:27550628a518c5c6c903d84f637fbecf287f6cb9ced3804838a1295dc1fd0759", size = 750542, upload-time = "2026-03-25T23:34:53.396Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/53/c776d80e9d26441bb3868457909b4e74dd9ccabd182e10b2b0ae7a07e265/cryptography-44.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:962bc30480a08d133e631e8dfd4783ab71cc9e33d5d7c1e192f0b7c06397bb88", size = 6670281, upload-time = "2025-05-02T19:34:50.665Z" }, - { url = "https://files.pythonhosted.org/packages/6a/06/af2cf8d56ef87c77319e9086601bef621bedf40f6f59069e1b6d1ec498c5/cryptography-44.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffc61e8f3bf5b60346d89cd3d37231019c17a081208dfbbd6e1605ba03fa137", size = 3959305, upload-time = "2025-05-02T19:34:53.042Z" }, - { url = "https://files.pythonhosted.org/packages/ae/01/80de3bec64627207d030f47bf3536889efee8913cd363e78ca9a09b13c8e/cryptography-44.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58968d331425a6f9eedcee087f77fd3c927c88f55368f43ff7e0a19891f2642c", size = 4171040, upload-time = "2025-05-02T19:34:54.675Z" }, - { url = "https://files.pythonhosted.org/packages/bd/48/bb16b7541d207a19d9ae8b541c70037a05e473ddc72ccb1386524d4f023c/cryptography-44.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:e28d62e59a4dbd1d22e747f57d4f00c459af22181f0b2f787ea83f5a876d7c76", size = 3963411, upload-time = "2025-05-02T19:34:56.61Z" }, - { url = "https://files.pythonhosted.org/packages/42/b2/7d31f2af5591d217d71d37d044ef5412945a8a8e98d5a2a8ae4fd9cd4489/cryptography-44.0.3-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:af653022a0c25ef2e3ffb2c673a50e5a0d02fecc41608f4954176f1933b12359", size = 3689263, upload-time = "2025-05-02T19:34:58.591Z" }, - { url = "https://files.pythonhosted.org/packages/25/50/c0dfb9d87ae88ccc01aad8eb93e23cfbcea6a6a106a9b63a7b14c1f93c75/cryptography-44.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:157f1f3b8d941c2bd8f3ffee0af9b049c9665c39d3da9db2dc338feca5e98a43", size = 4196198, upload-time = "2025-05-02T19:35:00.988Z" }, - { url = "https://files.pythonhosted.org/packages/66/c9/55c6b8794a74da652690c898cb43906310a3e4e4f6ee0b5f8b3b3e70c441/cryptography-44.0.3-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c6cd67722619e4d55fdb42ead64ed8843d64638e9c07f4011163e46bc512cf01", size = 3966502, upload-time = "2025-05-02T19:35:03.091Z" }, - { url = 
"https://files.pythonhosted.org/packages/b6/f7/7cb5488c682ca59a02a32ec5f975074084db4c983f849d47b7b67cc8697a/cryptography-44.0.3-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:b424563394c369a804ecbee9b06dfb34997f19d00b3518e39f83a5642618397d", size = 4196173, upload-time = "2025-05-02T19:35:05.018Z" }, - { url = "https://files.pythonhosted.org/packages/d2/0b/2f789a8403ae089b0b121f8f54f4a3e5228df756e2146efdf4a09a3d5083/cryptography-44.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c91fc8e8fd78af553f98bc7f2a1d8db977334e4eea302a4bfd75b9461c2d8904", size = 4087713, upload-time = "2025-05-02T19:35:07.187Z" }, - { url = "https://files.pythonhosted.org/packages/1d/aa/330c13655f1af398fc154089295cf259252f0ba5df93b4bc9d9c7d7f843e/cryptography-44.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:25cd194c39fa5a0aa4169125ee27d1172097857b27109a45fadc59653ec06f44", size = 4299064, upload-time = "2025-05-02T19:35:08.879Z" }, - { url = "https://files.pythonhosted.org/packages/10/a8/8c540a421b44fd267a7d58a1fd5f072a552d72204a3f08194f98889de76d/cryptography-44.0.3-cp37-abi3-win32.whl", hash = "sha256:3be3f649d91cb182c3a6bd336de8b61a0a71965bd13d1a04a0e15b39c3d5809d", size = 2773887, upload-time = "2025-05-02T19:35:10.41Z" }, - { url = "https://files.pythonhosted.org/packages/b9/0d/c4b1657c39ead18d76bbd122da86bd95bdc4095413460d09544000a17d56/cryptography-44.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:3883076d5c4cc56dbef0b898a74eb6992fdac29a7b9013870b34efe4ddb39a0d", size = 3209737, upload-time = "2025-05-02T19:35:12.12Z" }, - { url = "https://files.pythonhosted.org/packages/34/a3/ad08e0bcc34ad436013458d7528e83ac29910943cea42ad7dd4141a27bbb/cryptography-44.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:5639c2b16764c6f76eedf722dbad9a0914960d3489c0cc38694ddf9464f1bb2f", size = 6673501, upload-time = "2025-05-02T19:35:13.775Z" }, - { url = "https://files.pythonhosted.org/packages/b1/f0/7491d44bba8d28b464a5bc8cc709f25a51e3eac54c0a4444cf2473a57c37/cryptography-44.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ffef566ac88f75967d7abd852ed5f182da252d23fac11b4766da3957766759", size = 3960307, upload-time = "2025-05-02T19:35:15.917Z" }, - { url = "https://files.pythonhosted.org/packages/f7/c8/e5c5d0e1364d3346a5747cdcd7ecbb23ca87e6dea4f942a44e88be349f06/cryptography-44.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192ed30fac1728f7587c6f4613c29c584abdc565d7417c13904708db10206645", size = 4170876, upload-time = "2025-05-02T19:35:18.138Z" }, - { url = "https://files.pythonhosted.org/packages/73/96/025cb26fc351d8c7d3a1c44e20cf9a01e9f7cf740353c9c7a17072e4b264/cryptography-44.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:7d5fe7195c27c32a64955740b949070f21cba664604291c298518d2e255931d2", size = 3964127, upload-time = "2025-05-02T19:35:19.864Z" }, - { url = "https://files.pythonhosted.org/packages/01/44/eb6522db7d9f84e8833ba3bf63313f8e257729cf3a8917379473fcfd6601/cryptography-44.0.3-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3f07943aa4d7dad689e3bb1638ddc4944cc5e0921e3c227486daae0e31a05e54", size = 3689164, upload-time = "2025-05-02T19:35:21.449Z" }, - { url = "https://files.pythonhosted.org/packages/68/fb/d61a4defd0d6cee20b1b8a1ea8f5e25007e26aeb413ca53835f0cae2bcd1/cryptography-44.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb90f60e03d563ca2445099edf605c16ed1d5b15182d21831f58460c48bffb93", size = 4198081, upload-time = "2025-05-02T19:35:23.187Z" }, - { url = 
"https://files.pythonhosted.org/packages/1b/50/457f6911d36432a8811c3ab8bd5a6090e8d18ce655c22820994913dd06ea/cryptography-44.0.3-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:ab0b005721cc0039e885ac3503825661bd9810b15d4f374e473f8c89b7d5460c", size = 3967716, upload-time = "2025-05-02T19:35:25.426Z" }, - { url = "https://files.pythonhosted.org/packages/35/6e/dca39d553075980ccb631955c47b93d87d27f3596da8d48b1ae81463d915/cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3bb0847e6363c037df8f6ede57d88eaf3410ca2267fb12275370a76f85786a6f", size = 4197398, upload-time = "2025-05-02T19:35:27.678Z" }, - { url = "https://files.pythonhosted.org/packages/9b/9d/d1f2fe681eabc682067c66a74addd46c887ebacf39038ba01f8860338d3d/cryptography-44.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0cc66c74c797e1db750aaa842ad5b8b78e14805a9b5d1348dc603612d3e3ff5", size = 4087900, upload-time = "2025-05-02T19:35:29.312Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f5/3599e48c5464580b73b236aafb20973b953cd2e7b44c7c2533de1d888446/cryptography-44.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6866df152b581f9429020320e5eb9794c8780e90f7ccb021940d7f50ee00ae0b", size = 4301067, upload-time = "2025-05-02T19:35:31.547Z" }, - { url = "https://files.pythonhosted.org/packages/a7/6c/d2c48c8137eb39d0c193274db5c04a75dab20d2f7c3f81a7dcc3a8897701/cryptography-44.0.3-cp39-abi3-win32.whl", hash = "sha256:c138abae3a12a94c75c10499f1cbae81294a6f983b3af066390adee73f433028", size = 2775467, upload-time = "2025-05-02T19:35:33.805Z" }, - { url = "https://files.pythonhosted.org/packages/c9/ad/51f212198681ea7b0deaaf8846ee10af99fba4e894f67b353524eab2bbe5/cryptography-44.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:5d186f32e52e66994dce4f766884bcb9c68b8da62d61d9d215bfe5fb56d21334", size = 3210375, upload-time = "2025-05-02T19:35:35.369Z" }, - { url = "https://files.pythonhosted.org/packages/8d/4b/c11ad0b6c061902de5223892d680e89c06c7c4d606305eb8de56c5427ae6/cryptography-44.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:896530bc9107b226f265effa7ef3f21270f18a2026bc09fed1ebd7b66ddf6375", size = 3390230, upload-time = "2025-05-02T19:35:49.062Z" }, - { url = "https://files.pythonhosted.org/packages/58/11/0a6bf45d53b9b2290ea3cec30e78b78e6ca29dc101e2e296872a0ffe1335/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9b4d4a5dbee05a2c390bf212e78b99434efec37b17a4bff42f50285c5c8c9647", size = 3895216, upload-time = "2025-05-02T19:35:51.351Z" }, - { url = "https://files.pythonhosted.org/packages/0a/27/b28cdeb7270e957f0077a2c2bfad1b38f72f1f6d699679f97b816ca33642/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02f55fb4f8b79c1221b0961488eaae21015b69b210e18c386b69de182ebb1259", size = 4115044, upload-time = "2025-05-02T19:35:53.044Z" }, - { url = "https://files.pythonhosted.org/packages/35/b0/ec4082d3793f03cb248881fecefc26015813199b88f33e3e990a43f79835/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:dd3db61b8fe5be220eee484a17233287d0be6932d056cf5738225b9c05ef4fff", size = 3898034, upload-time = "2025-05-02T19:35:54.72Z" }, - { url = "https://files.pythonhosted.org/packages/0b/7f/adf62e0b8e8d04d50c9a91282a57628c00c54d4ae75e2b02a223bd1f2613/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:978631ec51a6bbc0b7e58f23b68a8ce9e5f09721940933e9c217068388789fe5", size = 4114449, upload-time = "2025-05-02T19:35:57.139Z" }, - { url = 
"https://files.pythonhosted.org/packages/87/62/d69eb4a8ee231f4bf733a92caf9da13f1c81a44e874b1d4080c25ecbb723/cryptography-44.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5d20cc348cca3a8aa7312f42ab953a56e15323800ca3ab0706b8cd452a3a056c", size = 3134369, upload-time = "2025-05-02T19:35:58.907Z" }, + { url = "https://files.pythonhosted.org/packages/47/23/9285e15e3bc57325b0a72e592921983a701efc1ee8f91c06c5f0235d86d9/cryptography-46.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:64235194bad039a10bb6d2d930ab3323baaec67e2ce36215fd0952fad0930ca8", size = 7176401, upload-time = "2026-03-25T23:33:22.096Z" }, + { url = "https://files.pythonhosted.org/packages/60/f8/e61f8f13950ab6195b31913b42d39f0f9afc7d93f76710f299b5ec286ae6/cryptography-46.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:26031f1e5ca62fcb9d1fcb34b2b60b390d1aacaa15dc8b895a9ed00968b97b30", size = 4275275, upload-time = "2026-03-25T23:33:23.844Z" }, + { url = "https://files.pythonhosted.org/packages/19/69/732a736d12c2631e140be2348b4ad3d226302df63ef64d30dfdb8db7ad1c/cryptography-46.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9a693028b9cbe51b5a1136232ee8f2bc242e4e19d456ded3fa7c86e43c713b4a", size = 4425320, upload-time = "2026-03-25T23:33:25.703Z" }, + { url = "https://files.pythonhosted.org/packages/d4/12/123be7292674abf76b21ac1fc0e1af50661f0e5b8f0ec8285faac18eb99e/cryptography-46.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:67177e8a9f421aa2d3a170c3e56eca4e0128883cf52a071a7cbf53297f18b175", size = 4278082, upload-time = "2026-03-25T23:33:27.423Z" }, + { url = "https://files.pythonhosted.org/packages/5b/ba/d5e27f8d68c24951b0a484924a84c7cdaed7502bac9f18601cd357f8b1d2/cryptography-46.0.6-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:d9528b535a6c4f8ff37847144b8986a9a143585f0540fbcb1a98115b543aa463", size = 4926514, upload-time = "2026-03-25T23:33:29.206Z" }, + { url = "https://files.pythonhosted.org/packages/34/71/1ea5a7352ae516d5512d17babe7e1b87d9db5150b21f794b1377eac1edc0/cryptography-46.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:22259338084d6ae497a19bae5d4c66b7ca1387d3264d1c2c0e72d9e9b6a77b97", size = 4457766, upload-time = "2026-03-25T23:33:30.834Z" }, + { url = "https://files.pythonhosted.org/packages/01/59/562be1e653accee4fdad92c7a2e88fced26b3fdfce144047519bbebc299e/cryptography-46.0.6-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:760997a4b950ff00d418398ad73fbc91aa2894b5c1db7ccb45b4f68b42a63b3c", size = 3986535, upload-time = "2026-03-25T23:33:33.02Z" }, + { url = "https://files.pythonhosted.org/packages/d6/8b/b1ebfeb788bf4624d36e45ed2662b8bd43a05ff62157093c1539c1288a18/cryptography-46.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3dfa6567f2e9e4c5dceb8ccb5a708158a2a871052fa75c8b78cb0977063f1507", size = 4277618, upload-time = "2026-03-25T23:33:34.567Z" }, + { url = "https://files.pythonhosted.org/packages/dd/52/a005f8eabdb28df57c20f84c44d397a755782d6ff6d455f05baa2785bd91/cryptography-46.0.6-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:cdcd3edcbc5d55757e5f5f3d330dd00007ae463a7e7aa5bf132d1f22a4b62b19", size = 4890802, upload-time = "2026-03-25T23:33:37.034Z" }, + { url = "https://files.pythonhosted.org/packages/ec/4d/8e7d7245c79c617d08724e2efa397737715ca0ec830ecb3c91e547302555/cryptography-46.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:d4e4aadb7fc1f88687f47ca20bb7227981b03afaae69287029da08096853b738", size = 4457425, upload-time = "2026-03-25T23:33:38.904Z" }, + { url = 
"https://files.pythonhosted.org/packages/1d/5c/f6c3596a1430cec6f949085f0e1a970638d76f81c3ea56d93d564d04c340/cryptography-46.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2b417edbe8877cda9022dde3a008e2deb50be9c407eef034aeeb3a8b11d9db3c", size = 4405530, upload-time = "2026-03-25T23:33:40.842Z" }, + { url = "https://files.pythonhosted.org/packages/7e/c9/9f9cea13ee2dbde070424e0c4f621c091a91ffcc504ffea5e74f0e1daeff/cryptography-46.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:380343e0653b1c9d7e1f55b52aaa2dbb2fdf2730088d48c43ca1c7c0abb7cc2f", size = 4667896, upload-time = "2026-03-25T23:33:42.781Z" }, + { url = "https://files.pythonhosted.org/packages/ad/b5/1895bc0821226f129bc74d00eccfc6a5969e2028f8617c09790bf89c185e/cryptography-46.0.6-cp311-abi3-win32.whl", hash = "sha256:bcb87663e1f7b075e48c3be3ecb5f0b46c8fc50b50a97cf264e7f60242dca3f2", size = 3026348, upload-time = "2026-03-25T23:33:45.021Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f8/c9bcbf0d3e6ad288b9d9aa0b1dee04b063d19e8c4f871855a03ab3a297ab/cryptography-46.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:6739d56300662c468fddb0e5e291f9b4d084bead381667b9e654c7dd81705124", size = 3483896, upload-time = "2026-03-25T23:33:46.649Z" }, + { url = "https://files.pythonhosted.org/packages/c4/cc/f330e982852403da79008552de9906804568ae9230da8432f7496ce02b71/cryptography-46.0.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:12cae594e9473bca1a7aceb90536060643128bb274fcea0fc459ab90f7d1ae7a", size = 7162776, upload-time = "2026-03-25T23:34:13.308Z" }, + { url = "https://files.pythonhosted.org/packages/49/b3/dc27efd8dcc4bff583b3f01d4a3943cd8b5821777a58b3a6a5f054d61b79/cryptography-46.0.6-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:639301950939d844a9e1c4464d7e07f902fe9a7f6b215bb0d4f28584729935d8", size = 4270529, upload-time = "2026-03-25T23:34:15.019Z" }, + { url = "https://files.pythonhosted.org/packages/e6/05/e8d0e6eb4f0d83365b3cb0e00eb3c484f7348db0266652ccd84632a3d58d/cryptography-46.0.6-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ed3775295fb91f70b4027aeba878d79b3e55c0b3e97eaa4de71f8f23a9f2eb77", size = 4414827, upload-time = "2026-03-25T23:34:16.604Z" }, + { url = "https://files.pythonhosted.org/packages/2f/97/daba0f5d2dc6d855e2dcb70733c812558a7977a55dd4a6722756628c44d1/cryptography-46.0.6-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8927ccfbe967c7df312ade694f987e7e9e22b2425976ddbf28271d7e58845290", size = 4271265, upload-time = "2026-03-25T23:34:18.586Z" }, + { url = "https://files.pythonhosted.org/packages/89/06/fe1fce39a37ac452e58d04b43b0855261dac320a2ebf8f5260dd55b201a9/cryptography-46.0.6-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:b12c6b1e1651e42ab5de8b1e00dc3b6354fdfd778e7fa60541ddacc27cd21410", size = 4916800, upload-time = "2026-03-25T23:34:20.561Z" }, + { url = "https://files.pythonhosted.org/packages/ff/8a/b14f3101fe9c3592603339eb5d94046c3ce5f7fc76d6512a2d40efd9724e/cryptography-46.0.6-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:063b67749f338ca9c5a0b7fe438a52c25f9526b851e24e6c9310e7195aad3b4d", size = 4448771, upload-time = "2026-03-25T23:34:22.406Z" }, + { url = "https://files.pythonhosted.org/packages/01/b3/0796998056a66d1973fd52ee89dc1bb3b6581960a91ad4ac705f182d398f/cryptography-46.0.6-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:02fad249cb0e090b574e30b276a3da6a149e04ee2f049725b1f69e7b8351ec70", size = 3978333, upload-time = "2026-03-25T23:34:24.281Z" }, + { url = 
"https://files.pythonhosted.org/packages/c5/3d/db200af5a4ffd08918cd55c08399dc6c9c50b0bc72c00a3246e099d3a849/cryptography-46.0.6-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7e6142674f2a9291463e5e150090b95a8519b2fb6e6aaec8917dd8d094ce750d", size = 4271069, upload-time = "2026-03-25T23:34:25.895Z" }, + { url = "https://files.pythonhosted.org/packages/d7/18/61acfd5b414309d74ee838be321c636fe71815436f53c9f0334bf19064fa/cryptography-46.0.6-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:456b3215172aeefb9284550b162801d62f5f264a081049a3e94307fe20792cfa", size = 4878358, upload-time = "2026-03-25T23:34:27.67Z" }, + { url = "https://files.pythonhosted.org/packages/8b/65/5bf43286d566f8171917cae23ac6add941654ccf085d739195a4eacf1674/cryptography-46.0.6-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:341359d6c9e68834e204ceaf25936dffeafea3829ab80e9503860dcc4f4dac58", size = 4448061, upload-time = "2026-03-25T23:34:29.375Z" }, + { url = "https://files.pythonhosted.org/packages/e0/25/7e49c0fa7205cf3597e525d156a6bce5b5c9de1fd7e8cb01120e459f205a/cryptography-46.0.6-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9a9c42a2723999a710445bc0d974e345c32adfd8d2fac6d8a251fa829ad31cfb", size = 4399103, upload-time = "2026-03-25T23:34:32.036Z" }, + { url = "https://files.pythonhosted.org/packages/44/46/466269e833f1c4718d6cd496ffe20c56c9c8d013486ff66b4f69c302a68d/cryptography-46.0.6-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6617f67b1606dfd9fe4dbfa354a9508d4a6d37afe30306fe6c101b7ce3274b72", size = 4659255, upload-time = "2026-03-25T23:34:33.679Z" }, + { url = "https://files.pythonhosted.org/packages/0a/09/ddc5f630cc32287d2c953fc5d32705e63ec73e37308e5120955316f53827/cryptography-46.0.6-cp38-abi3-win32.whl", hash = "sha256:7f6690b6c55e9c5332c0b59b9c8a3fb232ebf059094c17f9019a51e9827df91c", size = 3010660, upload-time = "2026-03-25T23:34:35.418Z" }, + { url = "https://files.pythonhosted.org/packages/1b/82/ca4893968aeb2709aacfb57a30dec6fa2ab25b10fa9f064b8882ce33f599/cryptography-46.0.6-cp38-abi3-win_amd64.whl", hash = "sha256:79e865c642cfc5c0b3eb12af83c35c5aeff4fa5c672dc28c43721c2c9fdd2f0f", size = 3471160, upload-time = "2026-03-25T23:34:37.191Z" }, + { url = "https://files.pythonhosted.org/packages/2e/84/7ccff00ced5bac74b775ce0beb7d1be4e8637536b522b5df9b73ada42da2/cryptography-46.0.6-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:2ea0f37e9a9cf0df2952893ad145fd9627d326a59daec9b0802480fa3bcd2ead", size = 3475444, upload-time = "2026-03-25T23:34:38.944Z" }, + { url = "https://files.pythonhosted.org/packages/bc/1f/4c926f50df7749f000f20eede0c896769509895e2648db5da0ed55db711d/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a3e84d5ec9ba01f8fd03802b2147ba77f0c8f2617b2aff254cedd551844209c8", size = 4218227, upload-time = "2026-03-25T23:34:40.871Z" }, + { url = "https://files.pythonhosted.org/packages/c6/65/707be3ffbd5f786028665c3223e86e11c4cda86023adbc56bd72b1b6bab5/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:12f0fa16cc247b13c43d56d7b35287ff1569b5b1f4c5e87e92cc4fcc00cd10c0", size = 4381399, upload-time = "2026-03-25T23:34:42.609Z" }, + { url = "https://files.pythonhosted.org/packages/f3/6d/73557ed0ef7d73d04d9aba745d2c8e95218213687ee5e76b7d236a5030fc/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:50575a76e2951fe7dbd1f56d181f8c5ceeeb075e9ff88e7ad997d2f42af06e7b", size = 4217595, upload-time = "2026-03-25T23:34:44.205Z" }, + { url = 
"https://files.pythonhosted.org/packages/9e/c5/e1594c4eec66a567c3ac4400008108a415808be2ce13dcb9a9045c92f1a0/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:90e5f0a7b3be5f40c3a0a0eafb32c681d8d2c181fc2a1bdabe9b3f611d9f6b1a", size = 4380912, upload-time = "2026-03-25T23:34:46.328Z" }, + { url = "https://files.pythonhosted.org/packages/1a/89/843b53614b47f97fe1abc13f9a86efa5ec9e275292c457af1d4a60dc80e0/cryptography-46.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6728c49e3b2c180ef26f8e9f0a883a2c585638db64cf265b49c9ba10652d430e", size = 3409955, upload-time = "2026-03-25T23:34:48.465Z" }, ] [[package]] @@ -1489,12 +1494,12 @@ dependencies = [ { name = "google-auth-httplib2" }, { name = "google-cloud-aiplatform" }, { name = "googleapis-common-protos" }, + { name = "graphon" }, { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "httpx-sse" }, { name = "jieba" }, { name = "json-repair" }, - { name = "jsonschema" }, { name = "langfuse" }, { name = "langsmith" }, { name = "litellm" }, @@ -1526,7 +1531,6 @@ dependencies = [ { name = "psycopg2-binary" }, { name = "pycryptodome" }, { name = "pydantic" }, - { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, { name = "pypandoc" }, @@ -1547,7 +1551,6 @@ dependencies = [ { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, { name = "weave" }, { name = "weaviate-client" }, - { name = "webvtt-py" }, { name = "yarl" }, ] @@ -1590,7 +1593,6 @@ dev = [ { name = "types-greenlet" }, { name = "types-html5lib" }, { name = "types-jmespath" }, - { name = "types-jsonschema" }, { name = "types-markdown" }, { name = "types-oauthlib" }, { name = "types-objgraph" }, @@ -1669,7 +1671,7 @@ requires-dist = [ { name = "azure-identity", specifier = "==1.25.3" }, { name = "beautifulsoup4", specifier = "==4.14.3" }, { name = "bleach", specifier = "~=6.3.0" }, - { name = "boto3", specifier = "==1.42.73" }, + { name = "boto3", specifier = "==1.42.78" }, { name = "bs4", specifier = "~=0.0.1" }, { name = "cachetools", specifier = "~=5.3.0" }, { name = "celery", specifier = "~=5.6.2" }, @@ -1692,12 +1694,12 @@ requires-dist = [ { name = "google-auth-httplib2", specifier = "==0.3.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, { name = "googleapis-common-protos", specifier = ">=1.65.0" }, - { name = "gunicorn", specifier = "~=25.1.0" }, + { name = "graphon", specifier = ">=0.1.2" }, + { name = "gunicorn", specifier = "~=25.3.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.28.0" }, { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, { name = "json-repair", specifier = ">=0.55.1" }, - { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.7.16" }, { name = "litellm", specifier = "==1.82.6" }, @@ -1729,7 +1731,6 @@ requires-dist = [ { name = "psycopg2-binary", specifier = "~=2.9.6" }, { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.12.5" }, - { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, { name = "pyjwt", specifier = "~=2.12.0" }, { name = "pypandoc", specifier = "~=1.13" }, @@ -1738,7 +1739,7 @@ requires-dist = [ { name = "python-dotenv", specifier = "==1.2.2" }, { name = "pyyaml", specifier = "~=6.0.1" }, { name = "readabilipy", specifier = "~=0.3.0" }, - { name = "redis", extras = ["hiredis"], specifier = 
"~=7.3.0" }, + { name = "redis", extras = ["hiredis"], specifier = "~=7.4.0" }, { name = "resend", specifier = "~=2.26.0" }, { name = "sendgrid", specifier = "~=6.12.3" }, { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.55.0" }, @@ -1750,7 +1751,6 @@ requires-dist = [ { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.20.4" }, - { name = "webvtt-py", specifier = "~=0.5.1" }, { name = "yarl", specifier = "~=1.23.0" }, ] @@ -1793,7 +1793,6 @@ dev = [ { name = "types-greenlet", specifier = "~=3.3.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, - { name = "types-jsonschema", specifier = "~=4.26.0" }, { name = "types-markdown", specifier = "~=3.10.2" }, { name = "types-oauthlib", specifier = "~=3.3.0" }, { name = "types-objgraph", specifier = "~=3.6.0" }, @@ -1811,7 +1810,7 @@ dev = [ { name = "types-pywin32", specifier = "~=311.0.0" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-redis", specifier = ">=4.6.0.20241004" }, - { name = "types-regex", specifier = "~=2026.2.28" }, + { name = "types-regex", specifier = "~=2026.3.32" }, { name = "types-setuptools", specifier = ">=80.9.0" }, { name = "types-shapely", specifier = "~=2.1.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, @@ -1839,7 +1838,7 @@ vdb = [ { name = "alibabacloud-gpdb20160503", specifier = "~=5.1.0" }, { name = "alibabacloud-tea-openapi", specifier = "~=0.4.3" }, { name = "chromadb", specifier = "==0.5.20" }, - { name = "clickhouse-connect", specifier = "~=0.14.1" }, + { name = "clickhouse-connect", specifier = "~=0.15.0" }, { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.5.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, @@ -1855,13 +1854,13 @@ vdb = [ { name = "pymochow", specifier = "==2.3.6" }, { name = "pyobvector", specifier = "~=0.2.17" }, { name = "qdrant-client", specifier = "==1.9.0" }, - { name = "tablestore", specifier = "==6.4.1" }, - { name = "tcvectordb", specifier = "~=2.0.0" }, + { name = "tablestore", specifier = "==6.4.2" }, + { name = "tcvectordb", specifier = "~=2.1.0" }, { name = "tidb-vector", specifier = "==0.0.15" }, { name = "upstash-vector", specifier = "==0.8.0" }, { name = "volcengine-compat", specifier = "~=1.0.0" }, { name = "weaviate-client", specifier = "==4.20.4" }, - { name = "xinference-client", specifier = "~=2.3.1" }, + { name = "xinference-client", specifier = "~=2.4.0" }, ] [[package]] @@ -2024,14 +2023,14 @@ wheels = [ [[package]] name = "faker" -version = "40.11.0" +version = "40.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tzdata", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/dc/b68e5378e5a7db0ab776efcdd53b6fe374b29d703e156fd5bb4c5437069e/faker-40.11.0.tar.gz", hash = "sha256:7c419299103b13126bd02ec14bd2b47b946edb5a5eedf305e66a193b25f9a734", size = 1957570, upload-time = "2026-03-13T14:36:11.844Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/e5/b16bf568a2f20fe7423282db4a4059dbcadef70e9029c1c106836f8edd84/faker-40.11.1.tar.gz", hash = "sha256:61965046e79e8cfde4337d243eac04c0d31481a7c010033141103b43f603100c", size = 1957415, upload-time = "2026-03-23T14:05:50.233Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/b1/fa/a86c6ba66f0308c95b9288b1e3eaccd934b545646f63494a86f1ec2f8c8e/faker-40.11.0-py3-none-any.whl", hash = "sha256:0e9816c950528d2a37d74863f3ef389ea9a3a936cbcde0b11b8499942e25bf90", size = 1989457, upload-time = "2026-03-13T14:36:09.792Z" }, + { url = "https://files.pythonhosted.org/packages/fc/ec/3c4b78eb0d2f6a81fb8cc9286745845bff661e6815741eff7a6ac5fcc9ea/faker-40.11.1-py3-none-any.whl", hash = "sha256:3af3a213ba8fb33ce6ba2af7aef2ac91363dae35d0cec0b2b0337d189e5bee2a", size = 1989484, upload-time = "2026-03-23T14:05:48.793Z" }, ] [[package]] @@ -2472,7 +2471,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.142.0" +version = "1.143.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -2488,9 +2487,9 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/41/0d/3063a0512d60cf18854a279e00ccb796429545464345ef821cf77cb93d05/google_cloud_aiplatform-1.142.0.tar.gz", hash = "sha256:87b49e002703dc14885093e9b264587db84222bef5f70f5a442d03f41beecdd1", size = 10207993, upload-time = "2026-03-20T22:49:13.797Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/08/939fb05870fdf155410a927e22f5b053d49f18e215618e102fba1d8bb147/google_cloud_aiplatform-1.143.0.tar.gz", hash = "sha256:1f0124a89795a6b473deb28724dd37d95334205df3a9c9c48d0b8d7a3d5d5cc4", size = 10215389, upload-time = "2026-03-25T18:30:15.444Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/8b/f29646d3fa940f0e38cfcc12137f4851856b50d7486a3c05103ebc78d82d/google_cloud_aiplatform-1.142.0-py2.py3-none-any.whl", hash = "sha256:17c91db9b613cbbafb2c36335b123686aeb2b4b8448be5134b565ae07165a39a", size = 8388991, upload-time = "2026-03-20T22:49:10.334Z" }, + { url = "https://files.pythonhosted.org/packages/90/14/16323e604e79dc63b528268f97a841c2c29dd8eb16395de6bf530c1a5ebe/google_cloud_aiplatform-1.143.0-py2.py3-none-any.whl", hash = "sha256:78df97d044859f743a9cc48b89a260d33579b0d548b1589bb3ae9f4c2afc0c5a", size = 8392705, upload-time = "2026-03-25T18:30:11.496Z" }, ] [[package]] @@ -2543,7 +2542,7 @@ wheels = [ [[package]] name = "google-cloud-storage" -version = "3.10.0" +version = "3.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2553,9 +2552,9 @@ dependencies = [ { name = "google-resumable-media" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7a/e3/747759eebc72e420c25903d6bc231d0ceb110b66ac7e6ee3f350417152cd/google_cloud_storage-3.10.0.tar.gz", hash = "sha256:1aeebf097c27d718d84077059a28d7e87f136f3700212215f1ceeae1d1c5d504", size = 17309829, upload-time = "2026-03-18T15:54:11.875Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/47/205eb8e9a1739b5345843e5a425775cbdc472cc38e7eda082ba5b8d02450/google_cloud_storage-3.10.1.tar.gz", hash = "sha256:97db9aa4460727982040edd2bd13ff3d5e2260b5331ad22895802da1fc2a5286", size = 17309950, upload-time = "2026-03-23T09:35:23.409Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/29/e2/d58442f4daee5babd9255cf492a1f3d114357164072f8339a22a3ad460a2/google_cloud_storage-3.10.0-py3-none-any.whl", hash = "sha256:0072e7783b201e45af78fd9779894cdb6bec2bf922ee932f3fcc16f8bce9b9a3", size = 324382, upload-time = "2026-03-18T15:54:10.091Z" }, + { url = 
"https://files.pythonhosted.org/packages/ad/ff/ca9ab2417fa913d75aae38bf40bf856bb2749a604b2e0f701b37cfcd23cc/google_cloud_storage-3.10.1-py3-none-any.whl", hash = "sha256:a72f656759b7b99bda700f901adcb3425a828d4a29f911bc26b3ea79c5b1217f", size = 324453, upload-time = "2026-03-23T09:35:21.368Z" }, ] [[package]] @@ -2613,14 +2612,14 @@ wheels = [ [[package]] name = "googleapis-common-protos" -version = "1.73.0" +version = "1.73.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/99/96/a0205167fa0154f4a542fd6925bdc63d039d88dab3588b875078107e6f06/googleapis_common_protos-1.73.0.tar.gz", hash = "sha256:778d07cd4fbeff84c6f7c72102f0daf98fa2bfd3fa8bea426edc545588da0b5a", size = 147323, upload-time = "2026-03-06T21:53:09.727Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/c0/4a54c386282c13449eca8bbe2ddb518181dc113e78d240458a68856b4d69/googleapis_common_protos-1.73.1.tar.gz", hash = "sha256:13114f0e9d2391756a0194c3a8131974ed7bffb06086569ba193364af59163b6", size = 147506, upload-time = "2026-03-26T22:17:38.451Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/69/28/23eea8acd65972bbfe295ce3666b28ac510dfcb115fac089d3edb0feb00a/googleapis_common_protos-1.73.0-py3-none-any.whl", hash = "sha256:dfdaaa2e860f242046be561e6d6cb5c5f1541ae02cfbcb034371aadb2942b4e8", size = 297578, upload-time = "2026-03-06T21:52:33.933Z" }, + { url = "https://files.pythonhosted.org/packages/dc/82/fcb6520612bec0c39b973a6c0954b6a0d948aadfe8f7e9487f60ceb8bfa6/googleapis_common_protos-1.73.1-py3-none-any.whl", hash = "sha256:e51f09eb0a43a8602f5a915870972e6b4a394088415c79d79605a46d8e826ee8", size = 297556, upload-time = "2026-03-26T22:15:58.455Z" }, ] [package.optional-dependencies] @@ -2652,6 +2651,34 @@ requests = [ { name = "requests-toolbelt" }, ] +[[package]] +name = "graphon" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "charset-normalizer" }, + { name = "httpx" }, + { name = "json-repair" }, + { name = "jsonschema" }, + { name = "orjson" }, + { name = "pandas", extra = ["excel"] }, + { name = "pydantic" }, + { name = "pydantic-extra-types" }, + { name = "pypandoc" }, + { name = "pypdfium2" }, + { name = "python-docx" }, + { name = "pyyaml" }, + { name = "tiktoken" }, + { name = "transformers" }, + { name = "typing-extensions" }, + { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, + { name = "webvtt-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/fc/0a5342a1c29bc367c2254c170ef130a84a60d8cd1c9cc84a7a85e96c1042/graphon-0.1.2.tar.gz", hash = "sha256:a2210629f93258ad2e7cbe85b5d4c6826814f6c679aa2a23ca100511363b9240", size = 214744, upload-time = "2026-03-27T20:09:53.802Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/46/65b5e366ec2d7017b6d6448e2635b3772d86840a6f7297277471b1bfbfbd/graphon-0.1.2-py3-none-any.whl", hash = "sha256:79f0c7796de7b8642d070730bb8bdaf1c68ccdfcecac38e0b2282e0543f0a6db", size = 314398, upload-time = "2026-03-27T20:09:52.524Z" }, +] + [[package]] name = "graphql-core" version = "3.2.7" @@ -2843,14 +2870,14 @@ wheels = [ [[package]] name = "gunicorn" -version = "25.1.0" +version = "25.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/66/13/ef67f59f6a7896fdc2c1d62b5665c5219d6b0a9a1784938eb9a28e55e128/gunicorn-25.1.0.tar.gz", hash = 
"sha256:1426611d959fa77e7de89f8c0f32eed6aa03ee735f98c01efba3e281b1c47616", size = 594377, upload-time = "2026-02-13T11:09:58.989Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/f4/e78fa054248fab913e2eab0332c6c2cb07421fca1ce56d8fe43b6aef57a4/gunicorn-25.3.0.tar.gz", hash = "sha256:f74e1b2f9f76f6cd1ca01198968bd2dd65830edc24b6e8e4d78de8320e2fe889", size = 634883, upload-time = "2026-03-27T00:00:26.092Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/da/73/4ad5b1f6a2e21cf1e85afdaad2b7b1a933985e2f5d679147a1953aaa192c/gunicorn-25.1.0-py3-none-any.whl", hash = "sha256:d0b1236ccf27f72cfe14bce7caadf467186f19e865094ca84221424e839b8b8b", size = 197067, upload-time = "2026-02-13T11:09:57.146Z" }, + { url = "https://files.pythonhosted.org/packages/43/c8/8aaf447698c4d59aa853fd318eed300b5c9e44459f242ab8ead6c9c09792/gunicorn-25.3.0-py3-none-any.whl", hash = "sha256:cacea387dab08cd6776501621c295a904fe8e3b7aae9a1a3cbb26f4e7ed54660", size = 208403, upload-time = "2026-03-27T00:00:27.386Z" }, ] [[package]] @@ -3083,14 +3110,14 @@ wheels = [ [[package]] name = "hypothesis" -version = "6.151.9" +version = "6.151.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "sortedcontainers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/19/e1/ef365ff480903b929d28e057f57b76cae51a30375943e33374ec9a165d9c/hypothesis-6.151.9.tar.gz", hash = "sha256:2f284428dda6c3c48c580de0e18470ff9c7f5ef628a647ee8002f38c3f9097ca", size = 463534, upload-time = "2026-02-16T22:59:23.09Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/dd/633e2cd62377333b7681628aee2ec1d88166f5bdf916b08c98b1e8288ad3/hypothesis-6.151.10.tar.gz", hash = "sha256:6c9565af8b4aa3a080b508f66ce9c2a77dd613c7e9073e27fc7e4ef9f45f8a27", size = 463762, upload-time = "2026-03-29T01:06:22.19Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/f7/5cc291d701094754a1d327b44d80a44971e13962881d9a400235726171da/hypothesis-6.151.9-py3-none-any.whl", hash = "sha256:7b7220585c67759b1b1ef839b1e6e9e3d82ed468cfc1ece43c67184848d7edd9", size = 529307, upload-time = "2026-02-16T22:59:20.443Z" }, + { url = "https://files.pythonhosted.org/packages/40/da/439bb2e451979f5e88c13bbebc3e9e17754429cfb528c93677b2bd81783b/hypothesis-6.151.10-py3-none-any.whl", hash = "sha256:b0d7728f0c8c2be009f89fcdd6066f70c5439aa0f94adbb06e98261d05f49b05", size = 529493, upload-time = "2026-03-29T01:06:19.161Z" }, ] [[package]] @@ -3905,7 +3932,7 @@ wheels = [ [[package]] name = "nltk" -version = "3.9.3" +version = "3.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -3913,9 +3940,9 @@ dependencies = [ { name = "regex" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, 
upload-time = "2026-02-24T12:05:46.54Z" }, + { url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" }, ] [[package]] @@ -4462,7 +4489,7 @@ wheels = [ [[package]] name = "opik" -version = "1.10.45" +version = "1.10.54" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "boto3-stubs", extra = ["bedrock-runtime"] }, @@ -4481,9 +4508,9 @@ dependencies = [ { name = "tqdm" }, { name = "uuid6" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/17/edea6308347cec62e6828de7c573c596559c502b54fa4f0c88a52e2e81f5/opik-1.10.45.tar.gz", hash = "sha256:d8d8627ba03d12def46965e03d58f611daaf5cf878b3d087c53fe1159788c140", size = 789876, upload-time = "2026-03-20T11:35:12.457Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/c9/ecc68c5ae32bf5b1074bdc713cb1543b8e2a46c58c814bf150fecf50f272/opik-1.10.54.tar.gz", hash = "sha256:46e29abf4656bd80b9cb339659d24ecf97b61f37c3fde594de75e5f59953e9d3", size = 812757, upload-time = "2026-03-27T11:23:06.109Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/17/150e9eecfa28cb23f7a0bfe83ae1486a11022b97fe6d12328b455784658d/opik-1.10.45-py3-none-any.whl", hash = "sha256:e8050d9e5e0d92ff587f156eacbdd02099897f39cfe79a98380b6c8ae9906b95", size = 1337714, upload-time = "2026-03-20T11:35:10.237Z" }, + { url = "https://files.pythonhosted.org/packages/58/91/1ae4e8a349da0620a6f0a4fc51cd00c3e75176939d022e8684379aee2928/opik-1.10.54-py3-none-any.whl", hash = "sha256:5f8ddabe5283ebe08d455e81b188d6e09ce1d1efa989f8b05567ef70f1e9aeda", size = 1379008, upload-time = "2026-03-27T11:23:04.582Z" }, ] [[package]] @@ -5249,7 +5276,7 @@ crypto = [ [[package]] name = "pymilvus" -version = "2.6.10" +version = "2.6.11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -5261,9 +5288,9 @@ dependencies = [ { name = "requests" }, { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/85/90362066ccda5ff6fec693a55693cde659fdcd36d08f1bd7012ae958248d/pymilvus-2.6.10.tar.gz", hash = "sha256:58a44ee0f1dddd7727ae830ef25325872d8946f029d801a37105164e6699f1b8", size = 1561042, upload-time = "2026-03-13T09:54:22.441Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/e6/0adc3b374f5c5d1eebd4f551b455c6865c449b170b17545001b208e2b153/pymilvus-2.6.11.tar.gz", hash = "sha256:a40c10322cde25184a8c3d84993a14dfb67ad2bdcfc5dff7e68b11a79ff8f6d8", size = 1583634, upload-time = "2026-03-27T06:25:46.023Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/10/fe7fbb6795aa20038afd55e9c653991e7c69fb24c741ebb39ba3b0aa5c13/pymilvus-2.6.10-py3-none-any.whl", hash = "sha256:a048b6f3ebad93742bca559beabf44fe578f0983555a109c4436b5fb2c1dbd40", size = 312797, upload-time = "2026-03-13T09:54:21.081Z" }, + { url = "https://files.pythonhosted.org/packages/9c/1c/bccb331d71f824738f80f11e9b8b4da47973c903826355526ae4fa2b762f/pymilvus-2.6.11-py3-none-any.whl", hash = "sha256:a11e1718b15045361c71ca671b959900cb7e2faae863c896f6b7e87bf2e4d10a", size = 315252, upload-time = "2026-03-27T06:25:44.215Z" }, ] [[package]] @@ -5785,14 +5812,14 @@ wheels = [ [[package]] name = "redis" -version = "7.3.0" +version = "7.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "async-timeout", marker = "python_full_version < '3.11.3'" }, ] 
-sdist = { url = "https://files.pythonhosted.org/packages/da/82/4d1a5279f6c1251d3d2a603a798a1137c657de9b12cfc1fba4858232c4d2/redis-7.3.0.tar.gz", hash = "sha256:4d1b768aafcf41b01022410b3cc4f15a07d9b3d6fe0c66fc967da2c88e551034", size = 4928081, upload-time = "2026-03-06T18:18:16.287Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/7f/3759b1d0d72b7c92f0d70ffd9dc962b7b7b5ee74e135f9d7d8ab06b8a318/redis-7.4.0.tar.gz", hash = "sha256:64a6ea7bf567ad43c964d2c30d82853f8df927c5c9017766c55a1d1ed95d18ad", size = 4943913, upload-time = "2026-03-24T09:14:37.53Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/28/84e57fce7819e81ec5aa1bd31c42b89607241f4fb1a3ea5b0d2dbeaea26c/redis-7.3.0-py3-none-any.whl", hash = "sha256:9d4fcb002a12a5e3c3fbe005d59c48a2cc231f87fbb2f6b70c2d89bb64fec364", size = 404379, upload-time = "2026-03-06T18:18:14.583Z" }, + { url = "https://files.pythonhosted.org/packages/74/3a/95deec7db1eb53979973ebd156f3369a72732208d1391cd2e5d127062a32/redis-7.4.0-py3-none-any.whl", hash = "sha256:a9c74a5c893a5ef8455a5adb793a31bb70feb821c86eccb62eebef5a19c429ec", size = 409772, upload-time = "2026-03-24T09:14:35.968Z" }, ] [package.optional-dependencies] @@ -5852,7 +5879,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -5860,9 +5887,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] @@ -5981,27 +6008,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.7" +version = "0.15.8" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/22/9e4f66ee588588dc6c9af6a994e12d26e19efbe874d1a909d09a6dac7a59/ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac", size = 4601277, upload-time = "2026-03-19T16:26:22.605Z" } +sdist = { url = "https://files.pythonhosted.org/packages/14/b0/73cf7550861e2b4824950b8b52eebdcc5adc792a00c514406556c5b80817/ruff-0.15.8.tar.gz", hash = "sha256:995f11f63597ee362130d1d5a327a87cb6f3f5eae3094c620bcc632329a4d26e", size = 4610921, upload-time = "2026-03-26T18:39:38.675Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/41/2f/0b08ced94412af091807b6119ca03755d651d3d93a242682bf020189db94/ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e", size = 10489037, upload-time = "2026-03-19T16:26:32.47Z" }, - { url = "https://files.pythonhosted.org/packages/91/4a/82e0fa632e5c8b1eba5ee86ecd929e8ff327bbdbfb3c6ac5d81631bef605/ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477", size = 10955433, upload-time = "2026-03-19T16:27:00.205Z" }, - { url = "https://files.pythonhosted.org/packages/ab/10/12586735d0ff42526ad78c049bf51d7428618c8b5c467e72508c694119df/ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e", size = 10269302, upload-time = "2026-03-19T16:26:26.183Z" }, - { url = "https://files.pythonhosted.org/packages/eb/5d/32b5c44ccf149a26623671df49cbfbd0a0ae511ff3df9d9d2426966a8d57/ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf", size = 10607625, upload-time = "2026-03-19T16:27:03.263Z" }, - { url = "https://files.pythonhosted.org/packages/5d/f1/f0001cabe86173aaacb6eb9bb734aa0605f9a6aa6fa7d43cb49cbc4af9c9/ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85", size = 10324743, upload-time = "2026-03-19T16:27:09.791Z" }, - { url = "https://files.pythonhosted.org/packages/7a/87/b8a8f3d56b8d848008559e7c9d8bf367934d5367f6d932ba779456e2f73b/ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0", size = 11138536, upload-time = "2026-03-19T16:27:06.101Z" }, - { url = "https://files.pythonhosted.org/packages/e4/f2/4fd0d05aab0c5934b2e1464784f85ba2eab9d54bffc53fb5430d1ed8b829/ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912", size = 11994292, upload-time = "2026-03-19T16:26:48.718Z" }, - { url = "https://files.pythonhosted.org/packages/64/22/fc4483871e767e5e95d1622ad83dad5ebb830f762ed0420fde7dfa9d9b08/ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036", size = 11398981, upload-time = "2026-03-19T16:26:54.513Z" }, - { url = "https://files.pythonhosted.org/packages/b0/99/66f0343176d5eab02c3f7fcd2de7a8e0dd7a41f0d982bee56cd1c24db62b/ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5", size = 11242422, upload-time = "2026-03-19T16:26:29.277Z" }, - { url = "https://files.pythonhosted.org/packages/5d/3a/a7060f145bfdcce4c987ea27788b30c60e2c81d6e9a65157ca8afe646328/ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12", size = 11232158, upload-time = "2026-03-19T16:26:42.321Z" }, - { url = "https://files.pythonhosted.org/packages/a7/53/90fbb9e08b29c048c403558d3cdd0adf2668b02ce9d50602452e187cd4af/ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c", size = 10577861, upload-time = "2026-03-19T16:26:57.459Z" }, - { url = 
"https://files.pythonhosted.org/packages/2f/aa/5f486226538fe4d0f0439e2da1716e1acf895e2a232b26f2459c55f8ddad/ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4", size = 10327310, upload-time = "2026-03-19T16:26:35.909Z" }, - { url = "https://files.pythonhosted.org/packages/99/9e/271afdffb81fe7bfc8c43ba079e9d96238f674380099457a74ccb3863857/ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d", size = 10840752, upload-time = "2026-03-19T16:26:45.723Z" }, - { url = "https://files.pythonhosted.org/packages/bf/29/a4ae78394f76c7759953c47884eb44de271b03a66634148d9f7d11e721bd/ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580", size = 11336961, upload-time = "2026-03-19T16:26:39.076Z" }, - { url = "https://files.pythonhosted.org/packages/26/6b/8786ba5736562220d588a2f6653e6c17e90c59ced34a2d7b512ef8956103/ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de", size = 10582538, upload-time = "2026-03-19T16:26:15.992Z" }, - { url = "https://files.pythonhosted.org/packages/2b/e9/346d4d3fffc6871125e877dae8d9a1966b254fbd92a50f8561078b88b099/ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1", size = 11755839, upload-time = "2026-03-19T16:26:19.897Z" }, - { url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" }, + { url = "https://files.pythonhosted.org/packages/4a/92/c445b0cd6da6e7ae51e954939cb69f97e008dbe750cfca89b8cedc081be7/ruff-0.15.8-py3-none-linux_armv6l.whl", hash = "sha256:cbe05adeba76d58162762d6b239c9056f1a15a55bd4b346cfd21e26cd6ad7bc7", size = 10527394, upload-time = "2026-03-26T18:39:41.566Z" }, + { url = "https://files.pythonhosted.org/packages/eb/92/f1c662784d149ad1414cae450b082cf736430c12ca78367f20f5ed569d65/ruff-0.15.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d3e3d0b6ba8dca1b7ef9ab80a28e840a20070c4b62e56d675c24f366ef330570", size = 10905693, upload-time = "2026-03-26T18:39:30.364Z" }, + { url = "https://files.pythonhosted.org/packages/ca/f2/7a631a8af6d88bcef997eb1bf87cc3da158294c57044aafd3e17030613de/ruff-0.15.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ee3ae5c65a42f273f126686353f2e08ff29927b7b7e203b711514370d500de3", size = 10323044, upload-time = "2026-03-26T18:39:33.37Z" }, + { url = "https://files.pythonhosted.org/packages/67/18/1bf38e20914a05e72ef3b9569b1d5c70a7ef26cd188d69e9ca8ef588d5bf/ruff-0.15.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdce027ada77baa448077ccc6ebb2fa9c3c62fd110d8659d601cf2f475858d94", size = 10629135, upload-time = "2026-03-26T18:39:44.142Z" }, + { url = "https://files.pythonhosted.org/packages/d2/e9/138c150ff9af60556121623d41aba18b7b57d95ac032e177b6a53789d279/ruff-0.15.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12e617fc01a95e5821648a6df341d80456bd627bfab8a829f7cfc26a14a4b4a3", size = 10348041, upload-time = "2026-03-26T18:39:52.178Z" }, + { url = 
"https://files.pythonhosted.org/packages/02/f1/5bfb9298d9c323f842c5ddeb85f1f10ef51516ac7a34ba446c9347d898df/ruff-0.15.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:432701303b26416d22ba696c39f2c6f12499b89093b61360abc34bcc9bf07762", size = 11121987, upload-time = "2026-03-26T18:39:55.195Z" }, + { url = "https://files.pythonhosted.org/packages/10/11/6da2e538704e753c04e8d86b1fc55712fdbdcc266af1a1ece7a51fff0d10/ruff-0.15.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d910ae974b7a06a33a057cb87d2a10792a3b2b3b35e33d2699fdf63ec8f6b17a", size = 11951057, upload-time = "2026-03-26T18:39:19.18Z" }, + { url = "https://files.pythonhosted.org/packages/83/f0/c9208c5fd5101bf87002fed774ff25a96eea313d305f1e5d5744698dc314/ruff-0.15.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2033f963c43949d51e6fdccd3946633c6b37c484f5f98c3035f49c27395a8ab8", size = 11464613, upload-time = "2026-03-26T18:40:06.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/22/d7f2fabdba4fae9f3b570e5605d5eb4500dcb7b770d3217dca4428484b17/ruff-0.15.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f29b989a55572fb885b77464cf24af05500806ab4edf9a0fd8977f9759d85b1", size = 11257557, upload-time = "2026-03-26T18:39:57.972Z" }, + { url = "https://files.pythonhosted.org/packages/71/8c/382a9620038cf6906446b23ce8632ab8c0811b8f9d3e764f58bedd0c9a6f/ruff-0.15.8-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:ac51d486bf457cdc985a412fb1801b2dfd1bd8838372fc55de64b1510eff4bec", size = 11169440, upload-time = "2026-03-26T18:39:22.205Z" }, + { url = "https://files.pythonhosted.org/packages/4d/0d/0994c802a7eaaf99380085e4e40c845f8e32a562e20a38ec06174b52ef24/ruff-0.15.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c9861eb959edab053c10ad62c278835ee69ca527b6dcd72b47d5c1e5648964f6", size = 10605963, upload-time = "2026-03-26T18:39:46.682Z" }, + { url = "https://files.pythonhosted.org/packages/19/aa/d624b86f5b0aad7cef6bbf9cd47a6a02dfdc4f72c92a337d724e39c9d14b/ruff-0.15.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8d9a5b8ea13f26ae90838afc33f91b547e61b794865374f114f349e9036835fb", size = 10357484, upload-time = "2026-03-26T18:39:49.176Z" }, + { url = "https://files.pythonhosted.org/packages/35/c3/e0b7835d23001f7d999f3895c6b569927c4d39912286897f625736e1fd04/ruff-0.15.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c2a33a529fb3cbc23a7124b5c6ff121e4d6228029cba374777bd7649cc8598b8", size = 10830426, upload-time = "2026-03-26T18:40:03.702Z" }, + { url = "https://files.pythonhosted.org/packages/f0/51/ab20b322f637b369383adc341d761eaaa0f0203d6b9a7421cd6e783d81b9/ruff-0.15.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:75e5cd06b1cf3f47a3996cfc999226b19aa92e7cce682dcd62f80d7035f98f49", size = 11345125, upload-time = "2026-03-26T18:39:27.799Z" }, + { url = "https://files.pythonhosted.org/packages/37/e6/90b2b33419f59d0f2c4c8a48a4b74b460709a557e8e0064cf33ad894f983/ruff-0.15.8-py3-none-win32.whl", hash = "sha256:bc1f0a51254ba21767bfa9a8b5013ca8149dcf38092e6a9eb704d876de94dc34", size = 10571959, upload-time = "2026-03-26T18:39:36.117Z" }, + { url = "https://files.pythonhosted.org/packages/1f/a2/ef467cb77099062317154c63f234b8a7baf7cb690b99af760c5b68b9ee7f/ruff-0.15.8-py3-none-win_amd64.whl", hash = "sha256:04f79eff02a72db209d47d665ba7ebcad609d8918a134f86cb13dd132159fc89", size = 11743893, upload-time = "2026-03-26T18:39:25.01Z" }, + { url = 
"https://files.pythonhosted.org/packages/15/e2/77be4fff062fa78d9b2a4dea85d14785dac5f1d0c1fb58ed52331f0ebe28/ruff-0.15.8-py3-none-win_arm64.whl", hash = "sha256:cf891fa8e3bb430c0e7fac93851a5978fc99c8fa2c053b57b118972866f8e5f2", size = 11048175, upload-time = "2026-03-26T18:40:01.06Z" }, ] [[package]] @@ -6402,7 +6429,7 @@ wheels = [ [[package]] name = "tablestore" -version = "6.4.1" +version = "6.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -6415,9 +6442,9 @@ dependencies = [ { name = "six" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/62/00/53f8eeb0016e7ad518f92b085de8855891d10581b42f86d15d1df7a56d33/tablestore-6.4.1.tar.gz", hash = "sha256:005c6939832f2ecd403e01220b7045de45f2e53f1ffaf0c2efc435810885fffb", size = 120319, upload-time = "2026-02-13T06:58:37.267Z" } +sdist = { url = "https://files.pythonhosted.org/packages/09/07/afa1d18521bab13bb813066892b73589937fcf68aea63a54b0b14dae17b5/tablestore-6.4.2.tar.gz", hash = "sha256:5251e14b7c7ebf3d49d37dde957b49c7dba04ee8715c2650109cc02f3b89cc77", size = 5071435, upload-time = "2026-03-26T15:39:06.498Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/96/a132bdecb753dc9dc34124a53019da29672baaa34485c8c504895897ea96/tablestore-6.4.1-py3-none-any.whl", hash = "sha256:616898d294dfe22f0d427463c241c6788374cdb2ace9aaf85673ce2c2a18d7e0", size = 141556, upload-time = "2026-02-13T06:58:35.579Z" }, + { url = "https://files.pythonhosted.org/packages/c7/3f/5fb3e8e5de36934fe38986b4e861657cebb3a6dfd97d32224cd40fc66359/tablestore-6.4.2-py3-none-any.whl", hash = "sha256:98c4cffa5eace4a3ea6fc2425263e733093c2baa43537f25dbaaf02e2b7882d8", size = 5114987, upload-time = "2026-03-26T15:39:04.074Z" }, ] [[package]] @@ -6443,7 +6470,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/20/81/be13f417065200182 [[package]] name = "tcvectordb" -version = "2.0.0" +version = "2.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -6456,9 +6483,9 @@ dependencies = [ { name = "ujson" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/16/21/3bcd466df20ac69408c0228b1c5e793cf3283085238d3ef5d352c556b6ad/tcvectordb-2.0.0.tar.gz", hash = "sha256:38c6ed17931b9bd702138941ca6cfe10b2b60301424ffa36b64a3c2686318941", size = 82209, upload-time = "2025-12-27T07:55:27.376Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/4c/3510489c20823c045a4f84c3f656b1af00b3fbbfa36efc494cf01492521f/tcvectordb-2.1.0.tar.gz", hash = "sha256:382615573f2b6d3e21535b686feac8895169b8eb56078fc73abb020676a1622f", size = 85691, upload-time = "2026-03-25T12:55:27.509Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/af/10/e807b273348edef3b321194bc13b67d2cd4df64e22f0404b9e39082415c7/tcvectordb-2.0.0-py3-none-any.whl", hash = "sha256:1731d9c6c0d17a4199872747ddfb1dd3feb26f14ffe7a657f8a5ac3af4ddcdd1", size = 96256, upload-time = "2025-12-27T07:55:24.362Z" }, + { url = "https://files.pythonhosted.org/packages/99/cf/7f340b4dc30ed0d2758915d1c2a4b2e9f0c90ce4f322b7cf17e571c80a45/tcvectordb-2.1.0-py3-none-any.whl", hash = "sha256:afbfc5f82bda70480921b2308148cbd0c51c8b45b3eef6cea64ddd003c7577e9", size = 99615, upload-time = "2026-03-25T12:55:26.004Z" }, ] [[package]] @@ -6850,18 +6877,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/91/915c4a6e6e9bd2bca3ec0c21c1771b175c59e204b85e57f3f572370fe753/types_jmespath-1.1.0.20260124-py3-none-any.whl", hash = 
"sha256:ec387666d446b15624215aa9cbd2867ffd885b6c74246d357c65e830c7a138b3", size = 11509, upload-time = "2026-01-24T03:18:45.536Z" }, ] -[[package]] -name = "types-jsonschema" -version = "4.26.0.20260202" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "referencing" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a1/07/68f63e715eb327ed2f5292e29e8be99785db0f72c7664d2c63bd4dbdc29d/types_jsonschema-4.26.0.20260202.tar.gz", hash = "sha256:29831baa4308865a9aec547a61797a06fc152b0dac8dddd531e002f32265cb07", size = 16168, upload-time = "2026-02-02T04:11:22.585Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/06/962d4f364f779d7389cd31a1bb581907b057f52f0ace2c119a8dd8409db6/types_jsonschema-4.26.0.20260202-py3-none-any.whl", hash = "sha256:41c95343abc4de9264e333a55e95dfb4d401e463856d0164eec9cb182e8746da", size = 15914, upload-time = "2026-02-02T04:11:21.61Z" }, -] - [[package]] name = "types-markdown" version = "3.10.2.20260211" @@ -6873,11 +6888,11 @@ wheels = [ [[package]] name = "types-oauthlib" -version = "3.3.0.20250822" +version = "3.3.0.20260324" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6a/6e/d08033f562053c459322333c46baa8cf8d2d8c18f30d46dd898c8fd8df77/types_oauthlib-3.3.0.20250822.tar.gz", hash = "sha256:2cd41587dd80c199e4230e3f086777e9ae525e89579c64afe5e0039ab09be9de", size = 25700, upload-time = "2025-08-22T03:02:41.378Z" } +sdist = { url = "https://files.pythonhosted.org/packages/91/38/543938f86d81bd6a78b8c355fe81bb8da0a26e4c28addfe3443e38a683d2/types_oauthlib-3.3.0.20260324.tar.gz", hash = "sha256:3c4cc07fa33886f881682237c1e445c5f1778b44efea118f4c1e4ede82cb52f2", size = 26030, upload-time = "2026-03-24T04:06:30.898Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/18/4b/00593b8b5d055550e1fcb9af2c42fa11b0a90bf16a94759a77bc1c3c0c72/types_oauthlib-3.3.0.20250822-py3-none-any.whl", hash = "sha256:b7f4c9b9eed0e020f454e0af800b10e93dd2efd196da65744b76910cce7e70d6", size = 48800, upload-time = "2025-08-22T03:02:40.427Z" }, + { url = "https://files.pythonhosted.org/packages/0e/60/26f0ddade4b2bb17b3d8f3ebaac436e5487caec28831da3d7ea309fe93b9/types_oauthlib-3.3.0.20260324-py3-none-any.whl", hash = "sha256:d24662033b04f4d50a2f1fed04c1b43ff2554aa037c1dafa0424f87100a46ccd", size = 48984, upload-time = "2026-03-24T04:06:29.696Z" }, ] [[package]] @@ -7028,11 +7043,11 @@ wheels = [ [[package]] name = "types-regex" -version = "2026.2.28.20260301" +version = "2026.3.32.20260329" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3a/ed/106958cb686316113b748ed4209fa363fd92b15759d5409c3930fed36606/types_regex-2026.2.28.20260301.tar.gz", hash = "sha256:644c231db3f368908320170c14905731a7ae5fabdac0f60f5d6d12ecdd3bc8dd", size = 13157, upload-time = "2026-03-01T04:11:13.559Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/d8/a3aca5775c573e56d201bbd76a827b84d851a4bce28e189e5acb9c7a0d15/types_regex-2026.3.32.20260329.tar.gz", hash = "sha256:12653e44694cb3e3ccdc39bab3d433d2a83fec1c01220e6871fd6f3cf434675c", size = 13111, upload-time = "2026-03-29T04:27:04.759Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/bb/9bc26fcf5155bd25efeca35f8ba6bffb8b3c9da2baac8bf40067606418f3/types_regex-2026.2.28.20260301-py3-none-any.whl", hash = "sha256:7da7a1fe67528238176a5844fd435ca90617cf605341308686afbc579fdea5c0", size = 11130, upload-time = "2026-03-01T04:11:11.454Z" }, + { url = 
"https://files.pythonhosted.org/packages/89/f4/a1db307e56753c49fb15fc88d70fadeb3f38897b28cab645cddd18054c79/types_regex-2026.3.32.20260329-py3-none-any.whl", hash = "sha256:861d0893bcfe08a57eb7486a502014e29dc2721d46dd5130798fbccafdb31cc0", size = 11128, upload-time = "2026-03-29T04:27:03.854Z" }, ] [[package]] @@ -7692,7 +7707,7 @@ wheels = [ [[package]] name = "xinference-client" -version = "2.3.1" +version = "2.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -7700,9 +7715,9 @@ dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bc/7a/33aeef9cffdc331de0046c25412622c5a16226d1b4e0cca9ed512ad00b9a/xinference_client-2.3.1.tar.gz", hash = "sha256:23ae225f47ff9adf4c6f7718c54993d1be8c704d727509f6e5cb670de3e02c4d", size = 58414, upload-time = "2026-03-15T05:53:23.994Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/f2/7640528fd4f816df19afe91d52332a658ad2d2bacb13471b0a27dbd0cf46/xinference_client-2.4.0.tar.gz", hash = "sha256:59de6d58f89126c8ff05136818e0756108e534858255d7c4c0673b804fd2d01d", size = 58386, upload-time = "2026-03-29T05:10:58.533Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/74/8d/d9ab0a457718050a279b9bb6515b7245d114118dc5e275f190ef2628dd16/xinference_client-2.3.1-py3-none-any.whl", hash = "sha256:f7c4f0b56635b46be9cfd9b2affa8e15275491597ac9b958e14b13da5745133e", size = 40012, upload-time = "2026-03-15T05:53:22.797Z" }, + { url = "https://files.pythonhosted.org/packages/73/cf/9d27e0095cc28691c73ff186b33556790c7b87f046ca2ecd517c80272592/xinference_client-2.4.0-py3-none-any.whl", hash = "sha256:2f9478b00fe15643f281fe4c0643e74479c8b7837d377000ff120702cda81efc", size = 40012, upload-time = "2026-03-29T05:10:57.279Z" }, ] [[package]] diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index a034083304..1d4ff4d86f 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -13,4 +13,4 @@ PYTEST_XDIST_ARGS="${PYTEST_XDIST_ARGS:--n auto}" pytest --timeout "${PYTEST_TIMEOUT}" ${PYTEST_XDIST_ARGS} api/tests/unit_tests --ignore=api/tests/unit_tests/controllers # Run controller tests sequentially to avoid import race conditions -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests/controllers +pytest --timeout "${PYTEST_TIMEOUT}" --cov-append api/tests/unit_tests/controllers diff --git a/docker/.env.example b/docker/.env.example index 8cf77cf56b..9fbf9a9e72 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -488,7 +488,8 @@ ALIYUN_OSS_REGION=ap-southeast-1 ALIYUN_OSS_AUTH_VERSION=v4 # Don't start with '/'. OSS doesn't support leading slash in object names. ALIYUN_OSS_PATH=your-path -ALIYUN_CLOUDBOX_ID=your-cloudbox-id +# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox. +#ALIYUN_CLOUDBOX_ID=your-cloudbox-id # Tencent COS Configuration # diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 98c2613a07..e55cf942c3 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -275,6 +275,7 @@ services: # Use the shared environment variables. 
<<: *shared-api-worker-env DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} + DB_SSL_MODE: ${DB_SSL_MODE:-disable} SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002} SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 1746bb567a..911da70a73 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -127,6 +127,8 @@ services: restart: always env_file: - ./middleware.env + extra_hosts: + - "host.docker.internal:host-gateway" environment: # Use the shared environment variables. LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 2a75de1a89..737a62020c 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -146,7 +146,6 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path} - ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} @@ -985,6 +984,7 @@ services: # Use the shared environment variables. <<: *shared-api-worker-env DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} + DB_SSL_MODE: ${DB_SSL_MODE:-disable} SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002} SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} diff --git a/docs/ar-SA/README.md b/docs/ar-SA/README.md index 99e3e3567e..af5a9bfdc6 100644 --- a/docs/ar-SA/README.md +++ b/docs/ar-SA/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

diff --git a/docs/bn-BD/README.md b/docs/bn-BD/README.md index f3fa68b466..5dceacb187 100644 --- a/docs/bn-BD/README.md +++ b/docs/bn-BD/README.md @@ -57,7 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

ডিফাই একটি ওপেন-সোর্স LLM অ্যাপ ডেভেলপমেন্ট প্ল্যাটফর্ম। এটি ইন্টুইটিভ ইন্টারফেস, এজেন্টিক AI ওয়ার্কফ্লো, RAG পাইপলাইন, এজেন্ট ক্যাপাবিলিটি, মডেল ম্যানেজমেন্ট, মনিটরিং সুবিধা এবং আরও অনেক কিছু একত্রিত করে, যা দ্রুত প্রোটোটাইপ থেকে প্রোডাকশন পর্যন্ত নিয়ে যেতে সহায়তা করে। diff --git a/docs/de-DE/README.md b/docs/de-DE/README.md index c71a0bfccf..1eab517a6d 100644 --- a/docs/de-DE/README.md +++ b/docs/de-DE/README.md @@ -57,7 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify ist eine Open-Source-Plattform zur Entwicklung von LLM-Anwendungen. Ihre intuitive Benutzeroberfläche vereint agentenbasierte KI-Workflows, RAG-Pipelines, Agentenfunktionen, Modellverwaltung, Überwachungsfunktionen und mehr, sodass Sie schnell von einem Prototyp in die Produktion übergehen können. diff --git a/docs/es-ES/README.md b/docs/es-ES/README.md index da81b51d6a..f4c60e3d8f 100644 --- a/docs/es-ES/README.md +++ b/docs/es-ES/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md index 291c8dab40..db8730b36b 100644 --- a/docs/fr-FR/README.md +++ b/docs/fr-FR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/hi-IN/README.md b/docs/hi-IN/README.md index bedeaa6246..ad712046b5 100644 --- a/docs/hi-IN/README.md +++ b/docs/hi-IN/README.md @@ -58,6 +58,8 @@ README Tiếng Việt README in Deutsch README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা README in हिन्दी

diff --git a/docs/it-IT/README.md b/docs/it-IT/README.md index 2e96335d3e..bca560b574 100644 --- a/docs/it-IT/README.md +++ b/docs/it-IT/README.md @@ -58,7 +58,10 @@ README Tiếng Việt README in Deutsch README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify è una piattaforma open-source per lo sviluppo di applicazioni LLM. La sua interfaccia intuitiva combina flussi di lavoro AI basati su agenti, pipeline RAG, funzionalità di agenti, gestione dei modelli, funzionalità di monitoraggio e altro ancora, permettendovi di passare rapidamente da un prototipo alla produzione. diff --git a/docs/ja-JP/README.md b/docs/ja-JP/README.md index 659ffbda51..298dcb95aa 100644 --- a/docs/ja-JP/README.md +++ b/docs/ja-JP/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/ko-KR/README.md b/docs/ko-KR/README.md index 2f6c526ef2..2dcacaae8b 100644 --- a/docs/ko-KR/README.md +++ b/docs/ko-KR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch - README in বাংলা + README in Italiano + README em Português do Brasil + README Slovenščina + README in বাংলা + README in हिन्दी

Dify는 오픈 소스 LLM 앱 개발 플랫폼입니다. 직관적인 인터페이스를 통해 AI 워크플로우, RAG 파이프라인, 에이전트 기능, 모델 관리, 관찰 기능 등을 결합하여 프로토타입에서 프로덕션까지 빠르게 전환할 수 있습니다. 주요 기능 목록은 다음과 같습니다:

diff --git a/docs/pt-BR/README.md b/docs/pt-BR/README.md index ed29ec0294..f7cc37a20f 100644 --- a/docs/pt-BR/README.md +++ b/docs/pt-BR/README.md @@ -58,7 +58,10 @@ README em Vietnamita README em Português - BR README in Deutsch + README in Italiano + README Slovenščina README in বাংলা + README in हिन्दी

Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. Sua interface intuitiva combina workflow de IA, pipeline RAG, capacidades de agente, gerenciamento de modelos, recursos de observabilidade e muito mais, permitindo que você vá rapidamente do protótipo à produção. Aqui está uma lista das principais funcionalidades: diff --git a/docs/sl-SI/README.md b/docs/sl-SI/README.md index caef2c303c..7b3fe76b5d 100644 --- a/docs/sl-SI/README.md +++ b/docs/sl-SI/README.md @@ -53,9 +53,12 @@ README بالعربية Türkçe README README Tiếng Việt - README Slovenščina README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify je odprtokodna platforma za razvoj aplikacij LLM. Njegov intuitivni vmesnik združuje agentski potek dela z umetno inteligenco, cevovod RAG, zmogljivosti agentov, upravljanje modelov, funkcije opazovanja in več, kar vam omogoča hiter prehod od prototipa do proizvodnje. diff --git a/docs/tlh/README.md b/docs/tlh/README.md index e2acd7734c..a8e63026c8 100644 --- a/docs/tlh/README.md +++ b/docs/tlh/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/tr-TR/README.md b/docs/tr-TR/README.md index 6361ca5dd9..cecc1b189c 100644 --- a/docs/tr-TR/README.md +++ b/docs/tr-TR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify, açık kaynaklı bir LLM uygulama geliştirme platformudur. Sezgisel arayüzü, AI iş akışı, RAG pipeline'ı, ajan yetenekleri, model yönetimi, gözlemlenebilirlik özellikleri ve daha fazlasını birleştirerek, prototipten üretime hızlıca geçmenizi sağlar. İşte temel özelliklerin bir listesi: diff --git a/docs/vi-VN/README.md b/docs/vi-VN/README.md index 3042a98d95..5230d53110 100644 --- a/docs/vi-VN/README.md +++ b/docs/vi-VN/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Giao diện trực quan kết hợp quy trình làm việc AI, mô hình RAG, khả năng tác nhân, quản lý mô hình, tính năng quan sát và hơn thế nữa, cho phép bạn nhanh chóng chuyển từ nguyên mẫu sang sản phẩm. Đây là danh sách các tính năng cốt lõi: diff --git a/docs/zh-CN/README.md b/docs/zh-CN/README.md index 15bb447ad8..8ba8009959 100644 --- a/docs/zh-CN/README.md +++ b/docs/zh-CN/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी
# diff --git a/docs/zh-TW/README.md b/docs/zh-TW/README.md index 14b343ba29..de5bab8679 100644 --- a/docs/zh-TW/README.md +++ b/docs/zh-TW/README.md @@ -57,6 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina + README in বাংলা + README in हिन्दी

Dify 是一個開源的 LLM 應用程式開發平台。其直觀的界面結合了智能代理工作流程、RAG 管道、代理功能、模型管理、可觀察性功能等,讓您能夠快速從原型進展到生產環境。 diff --git a/e2e/.gitignore b/e2e/.gitignore new file mode 100644 index 0000000000..96c1e0f3a1 --- /dev/null +++ b/e2e/.gitignore @@ -0,0 +1,6 @@ +node_modules/ +.auth/ +playwright-report/ +test-results/ +cucumber-report/ +.logs/ diff --git a/e2e/AGENTS.md b/e2e/AGENTS.md new file mode 100644 index 0000000000..245c9863d4 --- /dev/null +++ b/e2e/AGENTS.md @@ -0,0 +1,164 @@ +# E2E + +This package contains the repository-level end-to-end tests for Dify. + +This file is the canonical package guide for `e2e/`. Keep detailed workflow, architecture, debugging, and reporting documentation here. Keep `README.md` as a minimal pointer to this file so the two documents do not drift. + +The suite uses Cucumber for scenario definitions and Playwright as the browser execution layer. + +It tests: + +- backend API started from source +- frontend served from the production artifact +- middleware services started from Docker + +## Prerequisites + +- Node.js `^22.22.1` +- `pnpm` +- `uv` +- Docker + +Install Playwright browsers once: + +```bash +cd e2e +pnpm install +pnpm e2e:install +pnpm check +``` + +Use `pnpm check` as the default local verification step after editing E2E TypeScript, Cucumber support code, or feature glue. It runs formatting, linting, and type checks for this package. + +Common commands: + +```bash +# authenticated-only regression (default excludes @fresh) +# expects backend API, frontend artifact, and middleware stack to already be running +pnpm e2e + +# full reset + fresh install + authenticated scenarios +# starts required middleware/dependencies for you +pnpm e2e:full + +# run a tagged subset +pnpm e2e -- --tags @smoke + +# headed browser +pnpm e2e:headed -- --tags @smoke + +# slow down browser actions for local debugging +E2E_SLOW_MO=500 pnpm e2e:headed -- --tags @smoke +``` + +Frontend artifact behavior: + +- if `web/.next/BUILD_ID` exists, E2E reuses the existing build by default +- if you set `E2E_FORCE_WEB_BUILD=1`, E2E rebuilds the frontend before starting it + +## Lifecycle + +```mermaid +flowchart TD + A["Start E2E run"] --> B["run-cucumber.ts orchestrates setup/API/frontend"] + B --> C["support/web-server.ts starts or reuses frontend directly"] + C --> D["Cucumber loads config, steps, and support modules"] + D --> E["BeforeAll bootstraps shared auth state via /install"] + E --> F{"Which command is running?"} + F -->|`pnpm e2e`| G["Run config default tags: not @fresh and not @skip"] + F -->|`pnpm e2e:full*`| H["Override tags to not @skip"] + G --> I["Per-scenario BrowserContext from shared browser"] + H --> I + I --> J["Failure artifacts written to cucumber-report/artifacts"] +``` + +Ownership is split like this: + +- `scripts/setup.ts` is the single environment entrypoint for reset, middleware, backend, and frontend startup +- `run-cucumber.ts` orchestrates the E2E run and Cucumber invocation +- `support/web-server.ts` manages frontend reuse, startup, readiness, and shutdown +- `features/support/hooks.ts` manages auth bootstrap, scenario lifecycle, and diagnostics +- `features/support/world.ts` owns per-scenario typed context +- `features/step-definitions/` holds domain-oriented glue so the official VS Code Cucumber plugin works with default conventions when `e2e/` is opened as the workspace root + +Package layout: + +- `features/`: Gherkin scenarios grouped by capability +- `features/step-definitions/`: domain-oriented step definitions +- `features/support/hooks.ts`: suite lifecycle, 
auth-state bootstrap, diagnostics +- `features/support/world.ts`: shared scenario context +- `support/web-server.ts`: typed frontend startup/reuse logic +- `scripts/setup.ts`: reset and service lifecycle commands +- `scripts/run-cucumber.ts`: Cucumber orchestration entrypoint + +Behavior depends on instance state: + +- uninitialized instance: completes install and stores authenticated state +- initialized instance: signs in and reuses authenticated state + +Because of that, the `@fresh` install scenario only runs in the `pnpm e2e:full*` flows. The default `pnpm e2e*` flows exclude `@fresh` via Cucumber config tags so they can be re-run against an already initialized instance. + +Reset all persisted E2E state: + +```bash +pnpm e2e:reset +``` + +This removes: + +- `docker/volumes/db/data` +- `docker/volumes/redis/data` +- `docker/volumes/weaviate` +- `docker/volumes/plugin_daemon` +- `e2e/.auth` +- `e2e/.logs` +- `e2e/cucumber-report` + +Start the full middleware stack: + +```bash +pnpm e2e:middleware:up +``` + +Stop the full middleware stack: + +```bash +pnpm e2e:middleware:down +``` + +The middleware stack includes: + +- PostgreSQL +- Redis +- Weaviate +- Sandbox +- SSRF proxy +- Plugin daemon + +Fresh install verification: + +```bash +pnpm e2e:full +``` + +Run the Cucumber suite against an already running middleware stack: + +```bash +pnpm e2e:middleware:up +pnpm e2e +pnpm e2e:middleware:down +``` + +Artifacts and diagnostics: + +- `cucumber-report/report.html`: HTML report +- `cucumber-report/report.json`: JSON report +- `cucumber-report/artifacts/`: failure screenshots and HTML captures +- `.logs/cucumber-api.log`: backend startup log +- `.logs/cucumber-web.log`: frontend startup log + +Open the HTML report locally with: + +```bash +open cucumber-report/report.html +``` diff --git a/e2e/README.md b/e2e/README.md new file mode 100644 index 0000000000..9b4046eaff --- /dev/null +++ b/e2e/README.md @@ -0,0 +1,3 @@ +# E2E + +Canonical documentation for this package lives in [AGENTS.md](./AGENTS.md). 
diff --git a/e2e/cucumber.config.ts b/e2e/cucumber.config.ts new file mode 100644 index 0000000000..c162a6562e --- /dev/null +++ b/e2e/cucumber.config.ts @@ -0,0 +1,19 @@ +import type { IConfiguration } from '@cucumber/cucumber' + +const config = { + format: [ + 'progress-bar', + 'summary', + 'html:./cucumber-report/report.html', + 'json:./cucumber-report/report.json', + ], + import: ['features/**/*.ts'], + parallel: 1, + paths: ['features/**/*.feature'], + tags: process.env.E2E_CUCUMBER_TAGS || 'not @fresh and not @skip', + timeout: 60_000, +} satisfies Partial<IConfiguration> & { + timeout: number +} + +export default config diff --git a/e2e/features/apps/create-app.feature b/e2e/features/apps/create-app.feature new file mode 100644 index 0000000000..c0ca8ea4e0 --- /dev/null +++ b/e2e/features/apps/create-app.feature @@ -0,0 +1,10 @@ +@apps @authenticated +Feature: Create app + Scenario: Create a new blank app and redirect to the editor + Given I am signed in as the default E2E admin + When I open the apps console + And I start creating a blank app + And I enter a unique E2E app name + And I confirm app creation + Then I should land on the app editor + And I should see the "Orchestrate" text diff --git a/e2e/features/smoke/authenticated-entry.feature b/e2e/features/smoke/authenticated-entry.feature new file mode 100644 index 0000000000..3c1191a330 --- /dev/null +++ b/e2e/features/smoke/authenticated-entry.feature @@ -0,0 +1,8 @@ +@smoke @authenticated +Feature: Authenticated app console + Scenario: Open the apps console with the shared authenticated state + Given I am signed in as the default E2E admin + When I open the apps console + Then I should stay on the apps console + And I should see the "Create from Blank" button + And I should not see the "Sign in" button diff --git a/e2e/features/smoke/install.feature b/e2e/features/smoke/install.feature new file mode 100644 index 0000000000..39fc1f996b --- /dev/null +++ b/e2e/features/smoke/install.feature @@ -0,0 +1,7 @@ +@smoke @fresh +Feature: Fresh installation bootstrap + Scenario: Complete the initial installation bootstrap on a fresh instance + Given the last authentication bootstrap came from a fresh install + When I open the apps console + Then I should stay on the apps console + And I should see the "Create from Blank" button diff --git a/e2e/features/step-definitions/apps/create-app.steps.ts b/e2e/features/step-definitions/apps/create-app.steps.ts new file mode 100644 index 0000000000..b8e76c6f06 --- /dev/null +++ b/e2e/features/step-definitions/apps/create-app.steps.ts @@ -0,0 +1,29 @@ +import { Then, When } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +When('I start creating a blank app', async function (this: DifyWorld) { + const page = this.getPage() + + await expect(page.getByRole('button', { name: 'Create from Blank' })).toBeVisible() + await page.getByRole('button', { name: 'Create from Blank' }).click() +}) + +When('I enter a unique E2E app name', async function (this: DifyWorld) { + const appName = `E2E App ${Date.now()}` + + await this.getPage().getByPlaceholder('Give your app a name').fill(appName) +}) + +When('I confirm app creation', async function (this: DifyWorld) { + const createButton = this.getPage() + .getByRole('button', { name: /^Create(?:\s|$)/ }) + .last() + + await expect(createButton).toBeEnabled() + await createButton.click() +}) + +Then('I should land on the app editor', async function (this: DifyWorld) { + await
expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/(workflow|configuration)(?:\?.*)?$/) +}) diff --git a/e2e/features/step-definitions/common/auth.steps.ts b/e2e/features/step-definitions/common/auth.steps.ts new file mode 100644 index 0000000000..bf03c2d8f4 --- /dev/null +++ b/e2e/features/step-definitions/common/auth.steps.ts @@ -0,0 +1,11 @@ +import { Given } from '@cucumber/cucumber' +import type { DifyWorld } from '../../support/world' + +Given('I am signed in as the default E2E admin', async function (this: DifyWorld) { + const session = await this.getAuthSession() + + this.attach( + `Authenticated as ${session.adminEmail} using ${session.mode} flow at ${session.baseURL}.`, + 'text/plain', + ) +}) diff --git a/e2e/features/step-definitions/common/navigation.steps.ts b/e2e/features/step-definitions/common/navigation.steps.ts new file mode 100644 index 0000000000..b18ff035fa --- /dev/null +++ b/e2e/features/step-definitions/common/navigation.steps.ts @@ -0,0 +1,23 @@ +import { Then, When } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +When('I open the apps console', async function (this: DifyWorld) { + await this.getPage().goto('/apps') +}) + +Then('I should stay on the apps console', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/apps(?:\?.*)?$/) +}) + +Then('I should see the {string} button', async function (this: DifyWorld, label: string) { + await expect(this.getPage().getByRole('button', { name: label })).toBeVisible() +}) + +Then('I should not see the {string} button', async function (this: DifyWorld, label: string) { + await expect(this.getPage().getByRole('button', { name: label })).not.toBeVisible() +}) + +Then('I should see the {string} text', async function (this: DifyWorld, text: string) { + await expect(this.getPage().getByText(text)).toBeVisible({ timeout: 30_000 }) +}) diff --git a/e2e/features/step-definitions/smoke/install.steps.ts b/e2e/features/step-definitions/smoke/install.steps.ts new file mode 100644 index 0000000000..857e01a971 --- /dev/null +++ b/e2e/features/step-definitions/smoke/install.steps.ts @@ -0,0 +1,12 @@ +import { Given } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +Given( + 'the last authentication bootstrap came from a fresh install', + async function (this: DifyWorld) { + const session = await this.getAuthSession() + + expect(session.mode).toBe('install') + }, +) diff --git a/e2e/features/support/hooks.ts b/e2e/features/support/hooks.ts new file mode 100644 index 0000000000..a6862d79f5 --- /dev/null +++ b/e2e/features/support/hooks.ts @@ -0,0 +1,90 @@ +import { After, AfterAll, Before, BeforeAll, Status, setDefaultTimeout } from '@cucumber/cucumber' +import { chromium, type Browser } from '@playwright/test' +import { mkdir, writeFile } from 'node:fs/promises' +import path from 'node:path' +import { fileURLToPath } from 'node:url' +import { ensureAuthenticatedState } from '../../fixtures/auth' +import { baseURL, cucumberHeadless, cucumberSlowMo } from '../../test-env' +import type { DifyWorld } from './world' + +const e2eRoot = fileURLToPath(new URL('../..', import.meta.url)) +const artifactsDir = path.join(e2eRoot, 'cucumber-report', 'artifacts') + +let browser: Browser | undefined + +setDefaultTimeout(60_000) + +const sanitizeForPath = (value: string) => + value.replaceAll(/[^a-zA-Z0-9_-]+/g, '-').replaceAll(/^-+|-+$/g, '') + +const writeArtifact = async ( + 
scenarioName: string, + extension: 'html' | 'png', + contents: Buffer | string, +) => { + const artifactPath = path.join( + artifactsDir, + `${Date.now()}-${sanitizeForPath(scenarioName || 'scenario')}.${extension}`, + ) + await writeFile(artifactPath, contents) + + return artifactPath +} + +BeforeAll(async () => { + await mkdir(artifactsDir, { recursive: true }) + + browser = await chromium.launch({ + headless: cucumberHeadless, + slowMo: cucumberSlowMo, + }) + + console.log(`[e2e] session cache bootstrap against ${baseURL}`) + await ensureAuthenticatedState(browser, baseURL) +}) + +Before(async function (this: DifyWorld, { pickle }) { + if (!browser) throw new Error('Shared Playwright browser is not available.') + + await this.startAuthenticatedSession(browser) + this.scenarioStartedAt = Date.now() + + const tags = pickle.tags.map((tag) => tag.name).join(' ') + console.log(`[e2e] start ${pickle.name}${tags ? ` ${tags}` : ''}`) +}) + +After(async function (this: DifyWorld, { pickle, result }) { + const elapsedMs = this.scenarioStartedAt ? Date.now() - this.scenarioStartedAt : undefined + + if (result?.status !== Status.PASSED && this.page) { + const screenshot = await this.page.screenshot({ + fullPage: true, + }) + const screenshotPath = await writeArtifact(pickle.name, 'png', screenshot) + this.attach(screenshot, 'image/png') + + const html = await this.page.content() + const htmlPath = await writeArtifact(pickle.name, 'html', html) + this.attach(html, 'text/html') + + if (this.consoleErrors.length > 0) + this.attach(`Console Errors:\n${this.consoleErrors.join('\n')}`, 'text/plain') + + if (this.pageErrors.length > 0) + this.attach(`Page Errors:\n${this.pageErrors.join('\n')}`, 'text/plain') + + this.attach(`Artifacts:\n${[screenshotPath, htmlPath].join('\n')}`, 'text/plain') + } + + const status = result?.status || 'UNKNOWN' + console.log( + `[e2e] end ${pickle.name} status=${status}${elapsedMs ? 
` durationMs=${elapsedMs}` : ''}`, + ) + + await this.closeSession() +}) + +AfterAll(async () => { + await browser?.close() + browser = undefined +}) diff --git a/e2e/features/support/world.ts b/e2e/features/support/world.ts new file mode 100644 index 0000000000..15ab8daf16 --- /dev/null +++ b/e2e/features/support/world.ts @@ -0,0 +1,68 @@ +import { type IWorldOptions, World, setWorldConstructor } from '@cucumber/cucumber' +import type { Browser, BrowserContext, ConsoleMessage, Page } from '@playwright/test' +import { + authStatePath, + readAuthSessionMetadata, + type AuthSessionMetadata, +} from '../../fixtures/auth' +import { baseURL, defaultLocale } from '../../test-env' + +export class DifyWorld extends World { + context: BrowserContext | undefined + page: Page | undefined + consoleErrors: string[] = [] + pageErrors: string[] = [] + scenarioStartedAt: number | undefined + session: AuthSessionMetadata | undefined + + constructor(options: IWorldOptions) { + super(options) + this.resetScenarioState() + } + + resetScenarioState() { + this.consoleErrors = [] + this.pageErrors = [] + } + + async startAuthenticatedSession(browser: Browser) { + this.resetScenarioState() + this.context = await browser.newContext({ + baseURL, + locale: defaultLocale, + storageState: authStatePath, + }) + this.context.setDefaultTimeout(30_000) + this.page = await this.context.newPage() + this.page.setDefaultTimeout(30_000) + + this.page.on('console', (message: ConsoleMessage) => { + if (message.type() === 'error') this.consoleErrors.push(message.text()) + }) + this.page.on('pageerror', (error) => { + this.pageErrors.push(error.message) + }) + } + + getPage() { + if (!this.page) throw new Error('Playwright page has not been initialized for this scenario.') + + return this.page + } + + async getAuthSession() { + this.session ??= await readAuthSessionMetadata() + return this.session + } + + async closeSession() { + await this.context?.close() + this.context = undefined + this.page = undefined + this.session = undefined + this.scenarioStartedAt = undefined + this.resetScenarioState() + } +} + +setWorldConstructor(DifyWorld) diff --git a/e2e/fixtures/auth.ts b/e2e/fixtures/auth.ts new file mode 100644 index 0000000000..853bfff5ed --- /dev/null +++ b/e2e/fixtures/auth.ts @@ -0,0 +1,148 @@ +import type { Browser, Page } from '@playwright/test' +import { expect } from '@playwright/test' +import { mkdir, readFile, writeFile } from 'node:fs/promises' +import path from 'node:path' +import { fileURLToPath } from 'node:url' +import { defaultBaseURL, defaultLocale } from '../test-env' + +export type AuthSessionMetadata = { + adminEmail: string + baseURL: string + mode: 'install' | 'login' + usedInitPassword: boolean +} + +const WAIT_TIMEOUT_MS = 120_000 +const e2eRoot = fileURLToPath(new URL('..', import.meta.url)) + +export const authDir = path.join(e2eRoot, '.auth') +export const authStatePath = path.join(authDir, 'admin.json') +export const authMetadataPath = path.join(authDir, 'session.json') + +export const adminCredentials = { + email: process.env.E2E_ADMIN_EMAIL || 'e2e-admin@example.com', + name: process.env.E2E_ADMIN_NAME || 'E2E Admin', + password: process.env.E2E_ADMIN_PASSWORD || 'E2eAdmin12345', +} + +const initPassword = process.env.E2E_INIT_PASSWORD || 'E2eInit12345' + +export const resolveBaseURL = (configuredBaseURL?: string) => + configuredBaseURL || process.env.E2E_BASE_URL || defaultBaseURL + +export const readAuthSessionMetadata = async () => { + const content = await readFile(authMetadataPath, 'utf8') + 
return JSON.parse(content) as AuthSessionMetadata +} + +const escapeRegex = (value: string) => value.replaceAll(/[.*+?^${}()|[\]\\]/g, '\\$&') + +const appURL = (baseURL: string, pathname: string) => new URL(pathname, baseURL).toString() + +const waitForPageState = async (page: Page) => { + const installHeading = page.getByRole('heading', { name: 'Setting up an admin account' }) + const signInButton = page.getByRole('button', { name: 'Sign in' }) + const initPasswordField = page.getByLabel('Admin initialization password') + + const deadline = Date.now() + WAIT_TIMEOUT_MS + + while (Date.now() < deadline) { + if (await installHeading.isVisible().catch(() => false)) return 'install' as const + if (await signInButton.isVisible().catch(() => false)) return 'login' as const + if (await initPasswordField.isVisible().catch(() => false)) return 'init' as const + + await page.waitForTimeout(1_000) + } + + throw new Error(`Unable to determine auth page state for ${page.url()}`) +} + +const completeInitPasswordIfNeeded = async (page: Page) => { + const initPasswordField = page.getByLabel('Admin initialization password') + if (!(await initPasswordField.isVisible({ timeout: 3_000 }).catch(() => false))) return false + + await initPasswordField.fill(initPassword) + await page.getByRole('button', { name: 'Validate' }).click() + await expect(page.getByRole('heading', { name: 'Setting up an admin account' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + return true +} + +const completeInstall = async (page: Page, baseURL: string) => { + await expect(page.getByRole('heading', { name: 'Setting up an admin account' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await page.getByLabel('Email address').fill(adminCredentials.email) + await page.getByLabel('Username').fill(adminCredentials.name) + await page.getByLabel('Password').fill(adminCredentials.password) + await page.getByRole('button', { name: 'Set up' }).click() + + await expect(page).toHaveURL(new RegExp(`^${escapeRegex(baseURL)}/apps(?:\\?.*)?$`), { + timeout: WAIT_TIMEOUT_MS, + }) +} + +const completeLogin = async (page: Page, baseURL: string) => { + await expect(page.getByRole('button', { name: 'Sign in' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await page.getByLabel('Email address').fill(adminCredentials.email) + await page.getByLabel('Password').fill(adminCredentials.password) + await page.getByRole('button', { name: 'Sign in' }).click() + + await expect(page).toHaveURL(new RegExp(`^${escapeRegex(baseURL)}/apps(?:\\?.*)?$`), { + timeout: WAIT_TIMEOUT_MS, + }) +} + +export const ensureAuthenticatedState = async (browser: Browser, configuredBaseURL?: string) => { + const baseURL = resolveBaseURL(configuredBaseURL) + + await mkdir(authDir, { recursive: true }) + + const context = await browser.newContext({ + baseURL, + locale: defaultLocale, + }) + const page = await context.newPage() + + try { + await page.goto(appURL(baseURL, '/install'), { waitUntil: 'networkidle' }) + + let usedInitPassword = await completeInitPasswordIfNeeded(page) + let pageState = await waitForPageState(page) + + while (pageState === 'init') { + const completedInitPassword = await completeInitPasswordIfNeeded(page) + if (!completedInitPassword) + throw new Error(`Unable to validate initialization password for ${page.url()}`) + + usedInitPassword = true + pageState = await waitForPageState(page) + } + + if (pageState === 'install') await completeInstall(page, baseURL) + else await completeLogin(page, baseURL) + + await expect(page.getByRole('button', { 
name: 'Create from Blank' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await context.storageState({ path: authStatePath }) + + const metadata: AuthSessionMetadata = { + adminEmail: adminCredentials.email, + baseURL, + mode: pageState, + usedInitPassword, + } + + await writeFile(authMetadataPath, `${JSON.stringify(metadata, null, 2)}\n`, 'utf8') + } finally { + await context.close() + } +} diff --git a/e2e/package.json b/e2e/package.json new file mode 100644 index 0000000000..9b8a1f873f --- /dev/null +++ b/e2e/package.json @@ -0,0 +1,34 @@ +{ + "name": "dify-e2e", + "private": true, + "type": "module", + "scripts": { + "check": "vp check --fix", + "e2e": "tsx ./scripts/run-cucumber.ts", + "e2e:full": "tsx ./scripts/run-cucumber.ts --full", + "e2e:full:headed": "tsx ./scripts/run-cucumber.ts --full --headed", + "e2e:headed": "tsx ./scripts/run-cucumber.ts --headed", + "e2e:install": "playwright install --with-deps chromium", + "e2e:middleware:down": "tsx ./scripts/setup.ts middleware-down", + "e2e:middleware:up": "tsx ./scripts/setup.ts middleware-up", + "e2e:reset": "tsx ./scripts/setup.ts reset" + }, + "devDependencies": { + "@cucumber/cucumber": "12.7.0", + "@playwright/test": "1.51.1", + "@types/node": "25.5.0", + "tsx": "4.21.0", + "typescript": "5.9.3", + "vite-plus": "latest" + }, + "engines": { + "node": "^22.22.1" + }, + "packageManager": "pnpm@10.32.1", + "pnpm": { + "overrides": { + "vite": "npm:@voidzero-dev/vite-plus-core@latest", + "vitest": "npm:@voidzero-dev/vite-plus-test@latest" + } + } +} diff --git a/e2e/pnpm-lock.yaml b/e2e/pnpm-lock.yaml new file mode 100644 index 0000000000..b63458ad4a --- /dev/null +++ b/e2e/pnpm-lock.yaml @@ -0,0 +1,2632 @@ +lockfileVersion: '9.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +overrides: + vite: npm:@voidzero-dev/vite-plus-core@latest + vitest: npm:@voidzero-dev/vite-plus-test@latest + +importers: + + .: + devDependencies: + '@cucumber/cucumber': + specifier: 12.7.0 + version: 12.7.0 + '@playwright/test': + specifier: 1.51.1 + version: 1.51.1 + '@types/node': + specifier: 25.5.0 + version: 25.5.0 + tsx: + specifier: 4.21.0 + version: 4.21.0 + typescript: + specifier: 5.9.3 + version: 5.9.3 + vite-plus: + specifier: latest + version: 0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + +packages: + + '@babel/code-frame@7.29.0': + resolution: {integrity: sha512-9NhCeYjq9+3uxgdtp20LSiJXJvN0FeCtNGpJxuMFZ1Kv3cWUNb6DOhJwUvcVCzKGR66cw4njwM6hrJLqgOwbcw==} + engines: {node: '>=6.9.0'} + + '@babel/helper-validator-identifier@7.28.5': + resolution: {integrity: sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==} + engines: {node: '>=6.9.0'} + + '@colors/colors@1.5.0': + resolution: {integrity: sha512-ooWCrlZP11i8GImSjTHYHLkvFDP48nS4+204nGb1RiX/WXYHmJA2III9/e2DWVabCESdW7hBAEzHRqUn9OUVvQ==} + engines: {node: '>=0.1.90'} + + '@cucumber/ci-environment@13.0.0': + resolution: {integrity: sha512-cs+3NzfNkGbcmHPddjEv4TKFiBpZRQ6WJEEufB9mw+ExS22V/4R/zpDSEG+fsJ/iSNCd6A2sATdY8PFOyY3YnA==} + + '@cucumber/cucumber-expressions@19.0.0': + resolution: {integrity: sha512-4FKoOQh2Uf6F6/Ln+1OxuK8LkTg6PyAqekhf2Ix8zqV2M54sH+m7XNJNLhOFOAW/t9nxzRbw2CcvXbCLjcvHZg==} + + '@cucumber/cucumber@12.7.0': + resolution: {integrity: sha512-7A/9CJpJDxv1SQ7hAZU0zPn2yRxx6XMR+LO4T94Enm3cYNWsEEj+RGX38NLX4INT+H6w5raX3Csb/qs4vUBsOA==} + engines: {node: 20 || 22 || >=24} + hasBin: true + + 
'@cucumber/gherkin-streams@6.0.0': + resolution: {integrity: sha512-HLSHMmdDH0vCr7vsVEURcDA4WwnRLdjkhqr6a4HQ3i4RFK1wiDGPjBGVdGJLyuXuRdJpJbFc6QxHvT8pU4t6jw==} + hasBin: true + peerDependencies: + '@cucumber/gherkin': '>=22.0.0' + '@cucumber/message-streams': '>=4.0.0' + '@cucumber/messages': '>=17.1.1' + + '@cucumber/gherkin-utils@11.0.0': + resolution: {integrity: sha512-LJ+s4+TepHTgdKWDR4zbPyT7rQjmYIcukTwNbwNwgqr6i8Gjcmzf6NmtbYDA19m1ZFg6kWbFsmHnj37ZuX+kZA==} + hasBin: true + + '@cucumber/gherkin@38.0.0': + resolution: {integrity: sha512-duEXK+KDfQUzu3vsSzXjkxQ2tirF5PRsc1Xrts6THKHJO6mjw4RjM8RV+vliuDasmhhrmdLcOcM7d9nurNTJKw==} + + '@cucumber/html-formatter@23.0.0': + resolution: {integrity: sha512-WwcRzdM8Ixy4e53j+Frm3fKM5rNuIyWUfy4HajEN+Xk/YcjA6yW0ACGTFDReB++VDZz/iUtwYdTlPRY36NbqJg==} + peerDependencies: + '@cucumber/messages': '>=18' + + '@cucumber/junit-xml-formatter@0.9.0': + resolution: {integrity: sha512-WF+A7pBaXpKMD1i7K59Nk5519zj4extxY4+4nSgv5XLsGXHDf1gJnb84BkLUzevNtp2o2QzMG0vWLwSm8V5blw==} + peerDependencies: + '@cucumber/messages': '*' + + '@cucumber/message-streams@4.0.1': + resolution: {integrity: sha512-Kxap9uP5jD8tHUZVjTWgzxemi/0uOsbGjd4LBOSxcJoOCRbESFwemUzilJuzNTB8pcTQUh8D5oudUyxfkJOKmA==} + peerDependencies: + '@cucumber/messages': '>=17.1.1' + + '@cucumber/messages@32.0.1': + resolution: {integrity: sha512-1OSoW+GQvFUNAl6tdP2CTBexTXMNJF0094goVUcvugtQeXtJ0K8sCP0xbq7GGoiezs/eJAAOD03+zAPT64orHQ==} + + '@cucumber/pretty-formatter@1.0.1': + resolution: {integrity: sha512-A1lU4VVP0aUWdOTmpdzvXOyEYuPtBDI0xYwYJnmoMDplzxMdhcHk86lyyvYDoMoPzzq6OkOE3isuosvUU4X7IQ==} + peerDependencies: + '@cucumber/cucumber': '>=7.0.0' + '@cucumber/messages': '*' + + '@cucumber/query@14.7.0': + resolution: {integrity: sha512-fiqZ4gMEgYjmbuWproF/YeCdD5y+gD2BqgBIGbpihOsx6UlNsyzoDSfO+Tny0q65DxfK+pHo2UkPyEl7dO7wmQ==} + peerDependencies: + '@cucumber/messages': '*' + + '@cucumber/tag-expressions@9.1.0': + resolution: {integrity: sha512-bvHjcRFZ+J1TqIa9eFNO1wGHqwx4V9ZKV3hYgkuK/VahHx73uiP4rKV3JVrvWSMrwrFvJG6C8aEwnCWSvbyFdQ==} + + '@emnapi/core@1.9.1': + resolution: {integrity: sha512-mukuNALVsoix/w1BJwFzwXBN/dHeejQtuVzcDsfOEsdpCumXb/E9j8w11h5S54tT1xhifGfbbSm/ICrObRb3KA==} + + '@emnapi/runtime@1.9.1': + resolution: {integrity: sha512-VYi5+ZVLhpgK4hQ0TAjiQiZ6ol0oe4mBx7mVv7IflsiEp0OWoVsp/+f9Vc1hOhE0TtkORVrI1GvzyreqpgWtkA==} + + '@emnapi/wasi-threads@1.2.0': + resolution: {integrity: sha512-N10dEJNSsUx41Z6pZsXU8FjPjpBEplgH24sfkmITrBED1/U2Esum9F3lfLrMjKHHjmi557zQn7kR9R+XWXu5Rg==} + + '@esbuild/aix-ppc64@0.27.4': + resolution: {integrity: sha512-cQPwL2mp2nSmHHJlCyoXgHGhbEPMrEEU5xhkcy3Hs/O7nGZqEpZ2sUtLaL9MORLtDfRvVl2/3PAuEkYZH0Ty8Q==} + engines: {node: '>=18'} + cpu: [ppc64] + os: [aix] + + '@esbuild/android-arm64@0.27.4': + resolution: {integrity: sha512-gdLscB7v75wRfu7QSm/zg6Rx29VLdy9eTr2t44sfTW7CxwAtQghZ4ZnqHk3/ogz7xao0QAgrkradbBzcqFPasw==} + engines: {node: '>=18'} + cpu: [arm64] + os: [android] + + '@esbuild/android-arm@0.27.4': + resolution: {integrity: sha512-X9bUgvxiC8CHAGKYufLIHGXPJWnr0OCdR0anD2e21vdvgCI8lIfqFbnoeOz7lBjdrAGUhqLZLcQo6MLhTO2DKQ==} + engines: {node: '>=18'} + cpu: [arm] + os: [android] + + '@esbuild/android-x64@0.27.4': + resolution: {integrity: sha512-PzPFnBNVF292sfpfhiyiXCGSn9HZg5BcAz+ivBuSsl6Rk4ga1oEXAamhOXRFyMcjwr2DVtm40G65N3GLeH1Lvw==} + engines: {node: '>=18'} + cpu: [x64] + os: [android] + + '@esbuild/darwin-arm64@0.27.4': + resolution: {integrity: sha512-b7xaGIwdJlht8ZFCvMkpDN6uiSmnxxK56N2GDTMYPr2/gzvfdQN8rTfBsvVKmIVY/X7EM+/hJKEIbbHs9oA4tQ==} + engines: {node: '>=18'} + cpu: [arm64] + 
os: [darwin] + + '@esbuild/darwin-x64@0.27.4': + resolution: {integrity: sha512-sR+OiKLwd15nmCdqpXMnuJ9W2kpy0KigzqScqHI3Hqwr7IXxBp3Yva+yJwoqh7rE8V77tdoheRYataNKL4QrPw==} + engines: {node: '>=18'} + cpu: [x64] + os: [darwin] + + '@esbuild/freebsd-arm64@0.27.4': + resolution: {integrity: sha512-jnfpKe+p79tCnm4GVav68A7tUFeKQwQyLgESwEAUzyxk/TJr4QdGog9sqWNcUbr/bZt/O/HXouspuQDd9JxFSw==} + engines: {node: '>=18'} + cpu: [arm64] + os: [freebsd] + + '@esbuild/freebsd-x64@0.27.4': + resolution: {integrity: sha512-2kb4ceA/CpfUrIcTUl1wrP/9ad9Atrp5J94Lq69w7UwOMolPIGrfLSvAKJp0RTvkPPyn6CIWrNy13kyLikZRZQ==} + engines: {node: '>=18'} + cpu: [x64] + os: [freebsd] + + '@esbuild/linux-arm64@0.27.4': + resolution: {integrity: sha512-7nQOttdzVGth1iz57kxg9uCz57dxQLHWxopL6mYuYthohPKEK0vU0C3O21CcBK6KDlkYVcnDXY099HcCDXd9dA==} + engines: {node: '>=18'} + cpu: [arm64] + os: [linux] + + '@esbuild/linux-arm@0.27.4': + resolution: {integrity: sha512-aBYgcIxX/wd5n2ys0yESGeYMGF+pv6g0DhZr3G1ZG4jMfruU9Tl1i2Z+Wnj9/KjGz1lTLCcorqE2viePZqj4Eg==} + engines: {node: '>=18'} + cpu: [arm] + os: [linux] + + '@esbuild/linux-ia32@0.27.4': + resolution: {integrity: sha512-oPtixtAIzgvzYcKBQM/qZ3R+9TEUd1aNJQu0HhGyqtx6oS7qTpvjheIWBbes4+qu1bNlo2V4cbkISr8q6gRBFA==} + engines: {node: '>=18'} + cpu: [ia32] + os: [linux] + + '@esbuild/linux-loong64@0.27.4': + resolution: {integrity: sha512-8mL/vh8qeCoRcFH2nM8wm5uJP+ZcVYGGayMavi8GmRJjuI3g1v6Z7Ni0JJKAJW+m0EtUuARb6Lmp4hMjzCBWzA==} + engines: {node: '>=18'} + cpu: [loong64] + os: [linux] + + '@esbuild/linux-mips64el@0.27.4': + resolution: {integrity: sha512-1RdrWFFiiLIW7LQq9Q2NES+HiD4NyT8Itj9AUeCl0IVCA459WnPhREKgwrpaIfTOe+/2rdntisegiPWn/r/aAw==} + engines: {node: '>=18'} + cpu: [mips64el] + os: [linux] + + '@esbuild/linux-ppc64@0.27.4': + resolution: {integrity: sha512-tLCwNG47l3sd9lpfyx9LAGEGItCUeRCWeAx6x2Jmbav65nAwoPXfewtAdtbtit/pJFLUWOhpv0FpS6GQAmPrHA==} + engines: {node: '>=18'} + cpu: [ppc64] + os: [linux] + + '@esbuild/linux-riscv64@0.27.4': + resolution: {integrity: sha512-BnASypppbUWyqjd1KIpU4AUBiIhVr6YlHx/cnPgqEkNoVOhHg+YiSVxM1RLfiy4t9cAulbRGTNCKOcqHrEQLIw==} + engines: {node: '>=18'} + cpu: [riscv64] + os: [linux] + + '@esbuild/linux-s390x@0.27.4': + resolution: {integrity: sha512-+eUqgb/Z7vxVLezG8bVB9SfBie89gMueS+I0xYh2tJdw3vqA/0ImZJ2ROeWwVJN59ihBeZ7Tu92dF/5dy5FttA==} + engines: {node: '>=18'} + cpu: [s390x] + os: [linux] + + '@esbuild/linux-x64@0.27.4': + resolution: {integrity: sha512-S5qOXrKV8BQEzJPVxAwnryi2+Iq5pB40gTEIT69BQONqR7JH1EPIcQ/Uiv9mCnn05jff9umq/5nqzxlqTOg9NA==} + engines: {node: '>=18'} + cpu: [x64] + os: [linux] + + '@esbuild/netbsd-arm64@0.27.4': + resolution: {integrity: sha512-xHT8X4sb0GS8qTqiwzHqpY00C95DPAq7nAwX35Ie/s+LO9830hrMd3oX0ZMKLvy7vsonee73x0lmcdOVXFzd6Q==} + engines: {node: '>=18'} + cpu: [arm64] + os: [netbsd] + + '@esbuild/netbsd-x64@0.27.4': + resolution: {integrity: sha512-RugOvOdXfdyi5Tyv40kgQnI0byv66BFgAqjdgtAKqHoZTbTF2QqfQrFwa7cHEORJf6X2ht+l9ABLMP0dnKYsgg==} + engines: {node: '>=18'} + cpu: [x64] + os: [netbsd] + + '@esbuild/openbsd-arm64@0.27.4': + resolution: {integrity: sha512-2MyL3IAaTX+1/qP0O1SwskwcwCoOI4kV2IBX1xYnDDqthmq5ArrW94qSIKCAuRraMgPOmG0RDTA74mzYNQA9ow==} + engines: {node: '>=18'} + cpu: [arm64] + os: [openbsd] + + '@esbuild/openbsd-x64@0.27.4': + resolution: {integrity: sha512-u8fg/jQ5aQDfsnIV6+KwLOf1CmJnfu1ShpwqdwC0uA7ZPwFws55Ngc12vBdeUdnuWoQYx/SOQLGDcdlfXhYmXQ==} + engines: {node: '>=18'} + cpu: [x64] + os: [openbsd] + + '@esbuild/openharmony-arm64@0.27.4': + resolution: {integrity: 
sha512-JkTZrl6VbyO8lDQO3yv26nNr2RM2yZzNrNHEsj9bm6dOwwu9OYN28CjzZkH57bh4w0I2F7IodpQvUAEd1mbWXg==} + engines: {node: '>=18'} + cpu: [arm64] + os: [openharmony] + + '@esbuild/sunos-x64@0.27.4': + resolution: {integrity: sha512-/gOzgaewZJfeJTlsWhvUEmUG4tWEY2Spp5M20INYRg2ZKl9QPO3QEEgPeRtLjEWSW8FilRNacPOg8R1uaYkA6g==} + engines: {node: '>=18'} + cpu: [x64] + os: [sunos] + + '@esbuild/win32-arm64@0.27.4': + resolution: {integrity: sha512-Z9SExBg2y32smoDQdf1HRwHRt6vAHLXcxD2uGgO/v2jK7Y718Ix4ndsbNMU/+1Qiem9OiOdaqitioZwxivhXYg==} + engines: {node: '>=18'} + cpu: [arm64] + os: [win32] + + '@esbuild/win32-ia32@0.27.4': + resolution: {integrity: sha512-DAyGLS0Jz5G5iixEbMHi5KdiApqHBWMGzTtMiJ72ZOLhbu/bzxgAe8Ue8CTS3n3HbIUHQz/L51yMdGMeoxXNJw==} + engines: {node: '>=18'} + cpu: [ia32] + os: [win32] + + '@esbuild/win32-x64@0.27.4': + resolution: {integrity: sha512-+knoa0BDoeXgkNvvV1vvbZX4+hizelrkwmGJBdT17t8FNPwG2lKemmuMZlmaNQ3ws3DKKCxpb4zRZEIp3UxFCg==} + engines: {node: '>=18'} + cpu: [x64] + os: [win32] + + '@napi-rs/wasm-runtime@1.1.1': + resolution: {integrity: sha512-p64ah1M1ld8xjWv3qbvFwHiFVWrq1yFvV4f7w+mzaqiR4IlSgkqhcRdHwsGgomwzBH51sRY4NEowLxnaBjcW/A==} + + '@oxc-project/runtime@0.121.0': + resolution: {integrity: sha512-p0bQukD8OEHxzY4T9OlANBbEFGnOnjo1CYi50HES7OD36UO2yPh6T+uOJKLtlg06eclxroipRCpQGMpeH8EJ/g==} + engines: {node: ^20.19.0 || >=22.12.0} + + '@oxc-project/types@0.122.0': + resolution: {integrity: sha512-oLAl5kBpV4w69UtFZ9xqcmTi+GENWOcPF7FCrczTiBbmC0ibXxCwyvZGbO39rCVEuLGAZM84DH0pUIyyv/YJzA==} + + '@oxfmt/binding-android-arm-eabi@0.42.0': + resolution: {integrity: sha512-dsqPTYsozeokRjlrt/b4E7Pj0z3eS3Eg74TWQuuKbjY4VttBmA88rB7d50Xrd+TZ986qdXCNeZRPEzZHAe+jow==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [android] + + '@oxfmt/binding-android-arm64@0.42.0': + resolution: {integrity: sha512-t+aAjHxcr5eOBphFHdg1ouQU9qmZZoRxnX7UOJSaTwSoKsb6TYezNKO0YbWytGXCECObRqNcUxPoPr0KaraAIg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@oxfmt/binding-darwin-arm64@0.42.0': + resolution: {integrity: sha512-ulpSEYMKg61C5bRMZinFHrKJYRoKGVbvMEXA5zM1puX3O9T6Q4XXDbft20yrDijpYWeuG59z3Nabt+npeTsM1A==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@oxfmt/binding-darwin-x64@0.42.0': + resolution: {integrity: sha512-ttxLKhQYPdFiM8I/Ri37cvqChE4Xa562nNOsZFcv1CKTVLeEozXjKuYClNvxkXmNlcF55nzM80P+CQkdFBu+uQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@oxfmt/binding-freebsd-x64@0.42.0': + resolution: {integrity: sha512-Og7QS3yI3tdIKYZ58SXik0rADxIk2jmd+/YvuHRyKULWpG4V2fR5V4hvKm624Mc0cQET35waPXiCQWvjQEjwYQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + '@oxfmt/binding-linux-arm-gnueabihf@0.42.0': + resolution: {integrity: sha512-jwLOw/3CW4H6Vxcry4/buQHk7zm9Ne2YsidzTL1kpiMe4qqrRCwev3dkyWe2YkFmP+iZCQ7zku4KwjcLRoh8ew==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxfmt/binding-linux-arm-musleabihf@0.42.0': + resolution: {integrity: sha512-XwXu2vkMtiq2h7tfvN+WA/9/5/1IoGAVCFPiiQUvcAuG3efR97KNcRGM8BetmbYouFotQ2bDal3yyjUx6IPsTg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxfmt/binding-linux-arm64-gnu@0.42.0': + resolution: {integrity: sha512-ea7s/XUJoT7ENAtUQDudFe3nkSM3e3Qpz4nJFRdzO2wbgXEcjnchKLEsV3+t4ev3r8nWxIYr9NRjPWtnyIFJVA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-arm64-musl@0.42.0': + resolution: {integrity: 
sha512-+JA0YMlSdDqmacygGi2REp57c3fN+tzARD8nwsukx9pkCHK+6DkbAA9ojS4lNKsiBjIW8WWa0pBrBWhdZEqfuw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@oxfmt/binding-linux-ppc64-gnu@0.42.0': + resolution: {integrity: sha512-VfnET0j4Y5mdfCzh5gBt0NK28lgn5DKx+8WgSMLYYeSooHhohdbzwAStLki9pNuGy51y4I7IoW8bqwAaCMiJQg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-riscv64-gnu@0.42.0': + resolution: {integrity: sha512-gVlCbmBkB0fxBWbhBj9rcxezPydsQHf4MFKeHoTSPicOQ+8oGeTQgQ8EeesSybWeiFPVRx3bgdt4IJnH6nOjAA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-riscv64-musl@0.42.0': + resolution: {integrity: sha512-zN5OfstL0avgt/IgvRu0zjQzVh/EPkcLzs33E9LMAzpqlLWiPWeMDZyMGFlSRGOdDjuNmlZBCgj0pFnK5u32TQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [musl] + + '@oxfmt/binding-linux-s390x-gnu@0.42.0': + resolution: {integrity: sha512-9X6+H2L0qMc2sCAgO9HS03bkGLMKvOFjmEdchaFlany3vNZOjnVui//D8k/xZAtQv2vaCs1reD5KAgPoIU4msA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [s390x] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-x64-gnu@0.42.0': + resolution: {integrity: sha512-BajxJ6KQvMMdpXGPWhBGyjb2Jvx4uec0w+wi6TJZ6Tv7+MzPwe0pO8g5h1U0jyFgoaF7mDl6yKPW3ykWcbUJRw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-x64-musl@0.42.0': + resolution: {integrity: sha512-0wV284I6vc5f0AqAhgAbHU2935B4bVpncPoe5n/WzVZY/KnHgqxC8iSFGeSyLWEgstFboIcWkOPck7tqbdHkzA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@oxfmt/binding-openharmony-arm64@0.42.0': + resolution: {integrity: sha512-p4BG6HpGnhfgHk1rzZfyR6zcWkE7iLrWxyehHfXUy4Qa5j3e0roglFOdP/Nj5cJJ58MA3isQ5dlfkW2nNEpolw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@oxfmt/binding-win32-arm64-msvc@0.42.0': + resolution: {integrity: sha512-mn//WV60A+IetORDxYieYGAoQso4KnVRRjORDewMcod4irlRe0OSC7YPhhwaexYNPQz/GCFk+v9iUcZ2W22yxQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@oxfmt/binding-win32-ia32-msvc@0.42.0': + resolution: {integrity: sha512-3gWltUrvuz4LPJXWivoAxZ28Of2O4N7OGuM5/X3ubPXCEV8hmgECLZzjz7UYvSDUS3grfdccQwmjynm+51EFpw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ia32] + os: [win32] + + '@oxfmt/binding-win32-x64-msvc@0.42.0': + resolution: {integrity: sha512-Wg4TMAfQRL9J9AZevJ/ZNy3uyyDztDYQtGr4P8UyyzIhLhFrdSmz1J/9JT+rv0fiCDLaFOBQnj3f3K3+a5PzDQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + + '@oxlint-tsgolint/darwin-arm64@0.17.3': + resolution: {integrity: sha512-5aDl4mxXWs+Bj02pNrX6YY6v9KMZjLIytXoqolLEo0dfBNVeZUonZgJAa/w0aUmijwIRrBhxEzb42oLuUtfkGw==} + cpu: [arm64] + os: [darwin] + + '@oxlint-tsgolint/darwin-x64@0.17.3': + resolution: {integrity: sha512-gPBy4DS5ueCgXzko20XsNZzDe/Cxde056B+QuPLGvz05CGEAtmRfpImwnyY2lAXXjPL+SmnC/OYexu8zI12yHQ==} + cpu: [x64] + os: [darwin] + + '@oxlint-tsgolint/linux-arm64@0.17.3': + resolution: {integrity: sha512-+pkunvCfB6pB0G9qHVVXUao3nqzXQPo4O3DReIi+5nGa+bOU3J3Srgy+Zb8VyOL+WDsSMJ+U7+r09cKHWhz3hg==} + cpu: [arm64] + os: [linux] + + '@oxlint-tsgolint/linux-x64@0.17.3': + resolution: {integrity: sha512-/kW5oXtBThu4FjmgIBthdmMjWLzT3M1TEDQhxDu7hQU5xDeTd60CDXb2SSwKCbue9xu7MbiFoJu83LN0Z/d38g==} + cpu: [x64] + os: [linux] + + '@oxlint-tsgolint/win32-arm64@0.17.3': + resolution: {integrity: 
sha512-NMELRvbz4Ed4dxg8WiqZxtu3k4OJEp2B9KInZW+BMfqEqbwZdEJY83tbqz2hD1EjKO2akrqBQ0GpRUJEkd8kKw==} + cpu: [arm64] + os: [win32] + + '@oxlint-tsgolint/win32-x64@0.17.3': + resolution: {integrity: sha512-+pJ7r8J3SLPws5uoidVplZc8R/lpKyKPE6LoPGv9BME00Y1VjT6jWGx/dtUN8PWvcu3iTC6k+8u3ojFSJNmWTg==} + cpu: [x64] + os: [win32] + + '@oxlint/binding-android-arm-eabi@1.57.0': + resolution: {integrity: sha512-C7EiyfAJG4B70496eV543nKiq5cH0o/xIh/ufbjQz3SIvHhlDDsyn+mRFh+aW8KskTyUpyH2LGWL8p2oN6bl1A==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [android] + + '@oxlint/binding-android-arm64@1.57.0': + resolution: {integrity: sha512-9i80AresjZ/FZf5xK8tKFbhQnijD4s1eOZw6/FHUwD59HEZbVLRc2C88ADYJfLZrF5XofWDiRX/Ja9KefCLy7w==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@oxlint/binding-darwin-arm64@1.57.0': + resolution: {integrity: sha512-0eUfhRz5L2yKa9I8k3qpyl37XK3oBS5BvrgdVIx599WZK63P8sMbg+0s4IuxmIiZuBK68Ek+Z+gcKgeYf0otsg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@oxlint/binding-darwin-x64@1.57.0': + resolution: {integrity: sha512-UvrSuzBaYOue+QMAcuDITe0k/Vhj6KZGjfnI6x+NkxBTke/VoM7ZisaxgNY0LWuBkTnd1OmeQfEQdQ48fRjkQg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@oxlint/binding-freebsd-x64@1.57.0': + resolution: {integrity: sha512-wtQq0dCoiw4bUwlsNVDJJ3pxJA218fOezpgtLKrbQqUtQJcM9yP8z+I9fu14aHg0uyAxIY+99toL6uBa2r7nxA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + '@oxlint/binding-linux-arm-gnueabihf@1.57.0': + resolution: {integrity: sha512-qxFWl2BBBFcT4djKa+OtMdnLgoHEJXpqjyGwz8OhW35ImoCwR5qtAGqApNYce5260FQqoAHW8S8eZTjiX67Tsg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxlint/binding-linux-arm-musleabihf@1.57.0': + resolution: {integrity: sha512-SQoIsBU7J0bDW15/f0/RvxHfY3Y0+eB/caKBQtNFbuerTiA6JCYx9P1MrrFTwY2dTm/lMgTSgskvCEYk2AtG/Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxlint/binding-linux-arm64-gnu@1.57.0': + resolution: {integrity: sha512-jqxYd1W6WMeozsCmqe9Rzbu3SRrGTyGDAipRlRggetyYbUksJqJKvUNTQtZR/KFoJPb+grnSm5SHhdWrywv3RQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-arm64-musl@1.57.0': + resolution: {integrity: sha512-i66WyEPVEvq9bxRUCJ/MP5EBfnTDN3nhwEdFZFTO5MmLLvzngfWEG3NSdXQzTT3vk5B9i6C2XSIYBh+aG6uqyg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@oxlint/binding-linux-ppc64-gnu@1.57.0': + resolution: {integrity: sha512-oMZDCwz4NobclZU3pH+V1/upVlJZiZvne4jQP+zhJwt+lmio4XXr4qG47CehvrW1Lx2YZiIHuxM2D4YpkG3KVA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-riscv64-gnu@1.57.0': + resolution: {integrity: sha512-uoBnjJ3MMEBbfnWC1jSFr7/nSCkcQYa72NYoNtLl1imshDnWSolYCjzb8LVCwYCCfLJXD+0gBLD7fyC14c0+0g==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-riscv64-musl@1.57.0': + resolution: {integrity: sha512-BdrwD7haPZ8a9KrZhKJRSj6jwCor+Z8tHFZ3PT89Y3Jq5v3LfMfEePeAmD0LOTWpiTmzSzdmyw9ijneapiVHKQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [musl] + + '@oxlint/binding-linux-s390x-gnu@1.57.0': + resolution: {integrity: sha512-BNs+7ZNsRstVg2tpNxAXfMX/Iv5oZh204dVyb8Z37+/gCh+yZqNTlg6YwCLIMPSk5wLWIGOaQjT0GUOahKYImw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [s390x] + os: [linux] + libc: [glibc] + + 
'@oxlint/binding-linux-x64-gnu@1.57.0': + resolution: {integrity: sha512-AghS18w+XcENcAX0+BQGLiqjpqpaxKJa4cWWP0OWNLacs27vHBxu7TYkv9LUSGe5w8lOJHeMxcYfZNOAPqw2bg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-x64-musl@1.57.0': + resolution: {integrity: sha512-E/FV3GB8phu/Rpkhz5T96hAiJlGzn91qX5yj5gU754P5cmVGXY1Jw/VSjDSlZBCY3VHjsVLdzgdkJaomEmcNOg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@oxlint/binding-openharmony-arm64@1.57.0': + resolution: {integrity: sha512-xvZ2yZt0nUVfU14iuGv3V25jpr9pov5N0Wr28RXnHFxHCRxNDMtYPHV61gGLhN9IlXM96gI4pyYpLSJC5ClLCQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@oxlint/binding-win32-arm64-msvc@1.57.0': + resolution: {integrity: sha512-Z4D8Pd0AyHBKeazhdIXeUUy5sIS3Mo0veOlzlDECg6PhRRKgEsBJCCV1n+keUZtQ04OP+i7+itS3kOykUyNhDg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@oxlint/binding-win32-ia32-msvc@1.57.0': + resolution: {integrity: sha512-StOZ9nFMVKvevicbQfql6Pouu9pgbeQnu60Fvhz2S6yfMaii+wnueLnqQ5I1JPgNF0Syew4voBlAaHD13wH6tw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ia32] + os: [win32] + + '@oxlint/binding-win32-x64-msvc@1.57.0': + resolution: {integrity: sha512-6PuxhYgth8TuW0+ABPOIkGdBYw+qYGxgIdXPHSVpiCDm+hqTTWCmC739St1Xni0DJBt8HnSHTG67i1y6gr8qrA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + + '@playwright/test@1.51.1': + resolution: {integrity: sha512-nM+kEaTSAoVlXmMPH10017vn3FSiFqr/bh4fKg9vmAdMfd9SDqRZNvPSiAHADc/itWak+qPvMPZQOPwCBW7k7Q==} + engines: {node: '>=18'} + hasBin: true + + '@polka/url@1.0.0-next.29': + resolution: {integrity: sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww==} + + '@rolldown/binding-android-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-pv1y2Fv0JybcykuiiD3qBOBdz6RteYojRFY1d+b95WVuzx211CRh+ytI/+9iVyWQ6koTh5dawe4S/yRfOFjgaA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@rolldown/binding-darwin-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-cFYr6zTG/3PXXF3pUO+umXxt1wkRK/0AYT8lDwuqvRC+LuKYWSAQAQZjCWDQpAH172ZV6ieYrNnFzVVcnSflAg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@rolldown/binding-darwin-x64@1.0.0-rc.12': + resolution: {integrity: sha512-ZCsYknnHzeXYps0lGBz8JrF37GpE9bFVefrlmDrAQhOEi4IOIlcoU1+FwHEtyXGx2VkYAvhu7dyBf75EJQffBw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@rolldown/binding-freebsd-x64@1.0.0-rc.12': + resolution: {integrity: sha512-dMLeprcVsyJsKolRXyoTH3NL6qtsT0Y2xeuEA8WQJquWFXkEC4bcu1rLZZSnZRMtAqwtrF/Ib9Ddtpa/Gkge9Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.12': + resolution: {integrity: sha512-YqWjAgGC/9M1lz3GR1r1rP79nMgo3mQiiA+Hfo+pvKFK1fAJ1bCi0ZQVh8noOqNacuY1qIcfyVfP6HoyBRZ85Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-/I5AS4cIroLpslsmzXfwbe5OmWvSsrFuEw3mwvbQ1kDxJ822hFHIx+vsN/TAzNVyepI/j/GSzrtCIwQPeKCLIg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-arm64-musl@1.0.0-rc.12': + resolution: {integrity: sha512-V6/wZztnBqlx5hJQqNWwFdxIKN0m38p8Jas+VoSfgH54HSj9tKTt1dZvG6JRHcjh6D7TvrJPWFGaY9UBVOaWPw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + 
+ '@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-AP3E9BpcUYliZCxa3w5Kwj9OtEVDYK6sVoUzy4vTOJsjPOgdaJZKFmN4oOlX0Wp0RPV2ETfmIra9x1xuayFB7g==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-s390x-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-nWwpvUSPkoFmZo0kQazZYOrT7J5DGOJ/+QHHzjvNlooDZED8oH82Yg67HvehPPLAg5fUff7TfWFHQS8IV1n3og==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [s390x] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-x64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-RNrafz5bcwRy+O9e6P8Z/OCAJW/A+qtBczIqVYwTs14pf4iV1/+eKEjdOUta93q2TsT/FI0XYDP3TCky38LMAg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-x64-musl@1.0.0-rc.12': + resolution: {integrity: sha512-Jpw/0iwoKWx3LJ2rc1yjFrj+T7iHZn2JDg1Yny1ma0luviFS4mhAIcd1LFNxK3EYu3DHWCps0ydXQ5i/rrJ2ig==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@rolldown/binding-openharmony-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-vRugONE4yMfVn0+7lUKdKvN4D5YusEiPilaoO2sgUWpCvrncvWgPMzK00ZFFJuiPgLwgFNP5eSiUlv2tfc+lpA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@rolldown/binding-wasm32-wasi@1.0.0-rc.12': + resolution: {integrity: sha512-ykGiLr/6kkiHc0XnBfmFJuCjr5ZYKKofkx+chJWDjitX+KsJuAmrzWhwyOMSHzPhzOHOy7u9HlFoa5MoAOJ/Zg==} + engines: {node: '>=14.0.0'} + cpu: [wasm32] + + '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.12': + resolution: {integrity: sha512-5eOND4duWkwx1AzCxadcOrNeighiLwMInEADT0YM7xeEOOFcovWZCq8dadXgcRHSf3Ulh1kFo/qvzoFiCLOL1Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@rolldown/binding-win32-x64-msvc@1.0.0-rc.12': + resolution: {integrity: sha512-PyqoipaswDLAZtot351MLhrlrh6lcZPo2LSYE+VDxbVk24LVKAGOuE4hb8xZQmrPAuEtTZW8E6D2zc5EUZX4Lw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + + '@rolldown/pluginutils@1.0.0-rc.12': + resolution: {integrity: sha512-HHMwmarRKvoFsJorqYlFeFRzXZqCt2ETQlEDOb9aqssrnVBB1/+xgTGtuTrIk5vzLNX1MjMtTf7W9z3tsSbrxw==} + + '@standard-schema/spec@1.1.0': + resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==} + + '@teppeis/multimaps@3.0.0': + resolution: {integrity: sha512-ID7fosbc50TbT0MK0EG12O+gAP3W3Aa/Pz4DaTtQtEvlc9Odaqi0de+xuZ7Li2GtK4HzEX7IuRWS/JmZLksR3Q==} + engines: {node: '>=14'} + + '@tybys/wasm-util@0.10.1': + resolution: {integrity: sha512-9tTaPJLSiejZKx+Bmog4uSubteqTvFrVrURwkmHixBo0G4seD0zUxp98E1DzUBJxLQ3NPwXrGKDiVjwx/DpPsg==} + + '@types/chai@5.2.3': + resolution: {integrity: sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==} + + '@types/deep-eql@4.0.2': + resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} + + '@types/node@25.5.0': + resolution: {integrity: sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw==} + + '@types/normalize-package-data@2.4.4': + resolution: {integrity: sha512-37i+OaWTh9qeK4LSHPsyRC7NahnGotNuZvjLSgcPzblpHB3rrCJxAOgI5gCdKm7coonsaX1Of0ILiTcnZjbfxA==} + + '@voidzero-dev/vite-plus-core@0.1.14': + resolution: {integrity: sha512-CCWzdkfW0fo0cQNlIsYp5fOuH2IwKuPZEb2UY2Z8gXcp5pG74A82H2Pthj0heAuvYTAnfT7kEC6zM+RbiBgQbg==} + engines: {node: ^20.19.0 || >=22.12.0} + peerDependencies: + '@arethetypeswrong/core': ^0.18.1 + '@tsdown/css': 
0.21.4 + '@tsdown/exe': 0.21.4 + '@types/node': ^20.19.0 || >=22.12.0 + '@vitejs/devtools': ^0.1.0 + esbuild: ^0.27.0 + jiti: '>=1.21.0' + less: ^4.0.0 + publint: ^0.3.0 + sass: ^1.70.0 + sass-embedded: ^1.70.0 + stylus: '>=0.54.8' + sugarss: ^5.0.0 + terser: ^5.16.0 + tsx: ^4.8.1 + typescript: ^5.0.0 + unplugin-unused: ^0.5.0 + yaml: ^2.4.2 + peerDependenciesMeta: + '@arethetypeswrong/core': + optional: true + '@tsdown/css': + optional: true + '@tsdown/exe': + optional: true + '@types/node': + optional: true + '@vitejs/devtools': + optional: true + esbuild: + optional: true + jiti: + optional: true + less: + optional: true + publint: + optional: true + sass: + optional: true + sass-embedded: + optional: true + stylus: + optional: true + sugarss: + optional: true + terser: + optional: true + tsx: + optional: true + typescript: + optional: true + unplugin-unused: + optional: true + yaml: + optional: true + + '@voidzero-dev/vite-plus-darwin-arm64@0.1.14': + resolution: {integrity: sha512-q2ESUSbapwsxVRe/KevKATahNRraoX5nti3HT9S3266OHT5sMroBY14jaxTv74ekjQc9E6EPhyLGQWuWQuuBRw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@voidzero-dev/vite-plus-darwin-x64@0.1.14': + resolution: {integrity: sha512-UpcDZc9G99E/4HDRoobvYHxMvFOG5uv3RwEcq0HF70u4DsnEMl1z8RaJLeWV7a09LGwj9Q+YWC3Z4INWnTLs8g==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@voidzero-dev/vite-plus-linux-arm64-gnu@0.1.14': + resolution: {integrity: sha512-GIjn35RABUEDB9gHD26nRq7T72Te+Qy2+NIzogwEaUE728PvPkatF5gMCeF4sigCoc8c4qxDwsG+A2A2LYGnDg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@voidzero-dev/vite-plus-linux-arm64-musl@0.1.14': + resolution: {integrity: sha512-qo2RToGirG0XCcxZ2AEOuonLM256z6dNbJzDDIo5gWYA+cIKigFQJbkPyr25zsT1tsP2aY0OTxt2038XbVlRkQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.14': + resolution: {integrity: sha512-BsMWKZfdfGcYLxxLyaePpg6NW54xqzzcfq8sFUwKfwby0kgOKQ4WymUXyBvO9nnBb0ZPsJQrV0sx+Onac/LTaw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@voidzero-dev/vite-plus-linux-x64-musl@0.1.14': + resolution: {integrity: sha512-mOrEpj7ntW9RopGbcOYG/L0pOs0qHzUG4Vz7NXbuf4dbOSlY4JjyoMOIWxjKQORQht02Hzuf8YrMGNwa6AjVSQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@voidzero-dev/vite-plus-test@0.1.14': + resolution: {integrity: sha512-rjF+qpYD+5+THOJZ3gbE3+cxsk5sW7nJ0ODK7y6ZKeS4amREUMedEDYykzKBwR7OZDC/WwE90A0iLWCr6qAXhA==} + engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} + peerDependencies: + '@edge-runtime/vm': '*' + '@opentelemetry/api': ^1.9.0 + '@types/node': ^20.0.0 || ^22.0.0 || >=24.0.0 + '@vitest/ui': 4.1.1 + happy-dom: '*' + jsdom: '*' + vite: ^6.0.0 || ^7.0.0 || ^8.0.0 + peerDependenciesMeta: + '@edge-runtime/vm': + optional: true + '@opentelemetry/api': + optional: true + '@types/node': + optional: true + '@vitest/ui': + optional: true + happy-dom: + optional: true + jsdom: + optional: true + + '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.14': + resolution: {integrity: sha512-7iC+Ig+8D/zACy0IJf7w/vQ7duTjux9Ttmm3KOBdVWH4dl3JihydA7+SQVMhz71a4WiqJ6nPidoG8D6hUP4MVQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@voidzero-dev/vite-plus-win32-x64-msvc@0.1.14': + resolution: {integrity: sha512-yRJ/8yAYFluNHx0Ej6Kevx65MIeM3wFKklnxosVZRlz2ZRL1Ea1Qh3tWATr3Ipk1ciRxBv8KJgp6zXqjxtZSoQ==} + engines: {node: ^20.19.0 
|| >=22.12.0} + cpu: [x64] + os: [win32] + + ansi-regex@4.1.1: + resolution: {integrity: sha512-ILlv4k/3f6vfQ4OoP2AGvirOktlQ98ZEL1k9FaQjxa3L1abBgbuTDAdPOpvbGncC0BTVQrl+OM8xZGK6tWXt7g==} + engines: {node: '>=6'} + + ansi-regex@5.0.1: + resolution: {integrity: sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==} + engines: {node: '>=8'} + + ansi-styles@4.3.0: + resolution: {integrity: sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} + engines: {node: '>=8'} + + ansi-styles@5.2.0: + resolution: {integrity: sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==} + engines: {node: '>=10'} + + any-promise@1.3.0: + resolution: {integrity: sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==} + + assertion-error-formatter@3.0.0: + resolution: {integrity: sha512-6YyAVLrEze0kQ7CmJfUgrLHb+Y7XghmL2Ie7ijVa2Y9ynP3LV+VDiwFk62Dn0qtqbmY0BT0ss6p1xxpiF2PYbQ==} + + assertion-error@2.0.1: + resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} + engines: {node: '>=12'} + + balanced-match@4.0.4: + resolution: {integrity: sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==} + engines: {node: 18 || 20 || >=22} + + brace-expansion@5.0.5: + resolution: {integrity: sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==} + engines: {node: 18 || 20 || >=22} + + buffer-from@1.1.2: + resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==} + + cac@7.0.0: + resolution: {integrity: sha512-tixWYgm5ZoOD+3g6UTea91eow5z6AAHaho3g0V9CNSNb45gM8SmflpAc+GRd1InC4AqN/07Unrgp56Y94N9hJQ==} + engines: {node: '>=20.19.0'} + + capital-case@1.0.4: + resolution: {integrity: sha512-ds37W8CytHgwnhGGTi88pcPyR15qoNkOpYwmMMfnWqqWgESapLqvDx6huFjQ5vqWSn2Z06173XNA7LtMOeUh1A==} + + chalk@4.1.2: + resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==} + engines: {node: '>=10'} + + class-transformer@0.5.1: + resolution: {integrity: sha512-SQa1Ws6hUbfC98vKGxZH3KFY0Y1lm5Zm0SY8XX9zbK7FJCyVEac3ATW0RIpwzW+oOfmHE5PMPufDG9hCfoEOMw==} + + cli-table3@0.6.5: + resolution: {integrity: sha512-+W/5efTR7y5HRD7gACw9yQjqMVvEMLBHmboM/kPWam+H+Hmyrgjh6YncVKK122YZkXrLudzTuAukUw9FnMf7IQ==} + engines: {node: 10.* || >= 12.*} + + color-convert@2.0.1: + resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} + engines: {node: '>=7.0.0'} + + color-name@1.1.4: + resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} + + commander@14.0.0: + resolution: {integrity: sha512-2uM9rYjPvyq39NwLRqaiLtWHyDC1FvryJDa2ATTVims5YAS4PupsEQsDvP14FqhFr0P49CYDugi59xaxJlTXRA==} + engines: {node: '>=20'} + + commander@14.0.2: + resolution: {integrity: sha512-TywoWNNRbhoD0BXs1P3ZEScW8W5iKrnbithIl0YH+uCmBd0QpPOA8yc82DS3BIE5Ma6FnBVUsJ7wVUDz4dvOWQ==} + engines: {node: '>=20'} + + commander@14.0.3: + resolution: {integrity: sha512-H+y0Jo/T1RZ9qPP4Eh1pkcQcLRglraJaSLoyOtHxu6AapkjWVCy2Sit1QQ4x3Dng8qDlSsZEet7g5Pq06MvTgw==} + engines: {node: '>=20'} + + cross-spawn@7.0.6: + resolution: {integrity: sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==} + engines: {node: '>= 8'} + + 
debug@4.4.3: + resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + + detect-libc@2.1.2: + resolution: {integrity: sha512-Btj2BOOO83o3WyH59e8MgXsxEQVcarkUOpEYrubB0urwnN10yQ364rsiByU11nZlqWYZm05i/of7io4mzihBtQ==} + engines: {node: '>=8'} + + diff@4.0.4: + resolution: {integrity: sha512-X07nttJQkwkfKfvTPG/KSnE2OMdcUCao6+eXF3wmnIQRn2aPAHH3VxDbDOdegkd6JbPsXqShpvEOHfAT+nCNwQ==} + engines: {node: '>=0.3.1'} + + emoji-regex@8.0.0: + resolution: {integrity: sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==} + + error-stack-parser@2.1.4: + resolution: {integrity: sha512-Sk5V6wVazPhq5MhpO+AUxJn5x7XSXGl1R93Vn7i+zS15KDVxQijejNCrz8340/2bgLBjR9GtEG8ZVKONDjcqGQ==} + + es-module-lexer@1.7.0: + resolution: {integrity: sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==} + + esbuild@0.27.4: + resolution: {integrity: sha512-Rq4vbHnYkK5fws5NF7MYTU68FPRE1ajX7heQ/8QXXWqNgqqJ/GkmmyxIzUnf2Sr/bakf8l54716CcMGHYhMrrQ==} + engines: {node: '>=18'} + hasBin: true + + escape-string-regexp@1.0.5: + resolution: {integrity: sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==} + engines: {node: '>=0.8.0'} + + fdir@6.5.0: + resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==} + engines: {node: '>=12.0.0'} + peerDependencies: + picomatch: ^3 || ^4 + peerDependenciesMeta: + picomatch: + optional: true + + figures@3.2.0: + resolution: {integrity: sha512-yaduQFRKLXYOGgEn6AZau90j3ggSOyiqXU0F9JZfeXYhNa+Jk4X+s45A2zg5jns87GAFa34BBm2kXw4XpNcbdg==} + engines: {node: '>=8'} + + find-up-simple@1.0.1: + resolution: {integrity: sha512-afd4O7zpqHeRyg4PfDQsXmlDe2PfdHtJt6Akt8jOWaApLOZk5JXs6VMR29lz03pRe9mpykrRCYIYxaJYcfpncQ==} + engines: {node: '>=18'} + + fsevents@2.3.2: + resolution: {integrity: sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + + fsevents@2.3.3: + resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + + get-tsconfig@4.13.7: + resolution: {integrity: sha512-7tN6rFgBlMgpBML5j8typ92BKFi2sFQvIdpAqLA2beia5avZDrMs0FLZiM5etShWq5irVyGcGMEA1jcDaK7A/Q==} + + glob@13.0.6: + resolution: {integrity: sha512-Wjlyrolmm8uDpm/ogGyXZXb1Z+Ca2B8NbJwqBVg0axK9GbBeoS7yGV6vjXnYdGm6X53iehEuxxbyiKp8QmN4Vw==} + engines: {node: 18 || 20 || >=22} + + global-dirs@3.0.1: + resolution: {integrity: sha512-NBcGGFbBA9s1VzD41QXDG+3++t9Mn5t1FpLdhESY6oKY4gYTFpX4wO3sqGUa0Srjtbfj3szX0RnemmrVRUdULA==} + engines: {node: '>=10'} + + has-ansi@4.0.1: + resolution: {integrity: sha512-Qr4RtTm30xvEdqUXbSBVWDu+PrTokJOwe/FU+VdfJPk+MXAPoeOzKpRyrDTnZIJwAkQ4oBLTU53nu0HrkF/Z2A==} + engines: {node: '>=8'} + + has-flag@4.0.0: + resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} + engines: {node: '>=8'} + + hosted-git-info@9.0.2: + resolution: {integrity: sha512-M422h7o/BR3rmCQ8UHi7cyyMqKltdP9Uo+J2fXK+RSAY+wTcKOIRyhTuKv4qn+DJf3g+PL890AzId5KZpX+CBg==} + engines: {node: ^20.17.0 || >=22.9.0} + + indent-string@4.0.0: + resolution: {integrity: 
sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==} + engines: {node: '>=8'} + + index-to-position@1.2.0: + resolution: {integrity: sha512-Yg7+ztRkqslMAS2iFaU+Oa4KTSidr63OsFGlOrJoW981kIYO3CGCS3wA95P1mUi/IVSJkn0D479KTJpVpvFNuw==} + engines: {node: '>=18'} + + ini@2.0.0: + resolution: {integrity: sha512-7PnF4oN3CvZF23ADhA5wRaYEQpJ8qygSkbtTXWBeXWXmEVRXK+1ITciHWwHhsjv1TmW0MgacIv6hEi5pX5NQdA==} + engines: {node: '>=10'} + + is-fullwidth-code-point@3.0.0: + resolution: {integrity: sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==} + engines: {node: '>=8'} + + is-installed-globally@0.4.0: + resolution: {integrity: sha512-iwGqO3J21aaSkC7jWnHP/difazwS7SFeIqxv6wEtLU8Y5KlzFTjyqcSIT0d8s4+dDhKytsk9PJZ2BkS5eZwQRQ==} + engines: {node: '>=10'} + + is-path-inside@3.0.3: + resolution: {integrity: sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ==} + engines: {node: '>=8'} + + is-stream@2.0.1: + resolution: {integrity: sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==} + engines: {node: '>=8'} + + isexe@2.0.0: + resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==} + + js-tokens@4.0.0: + resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} + + knuth-shuffle-seeded@1.0.6: + resolution: {integrity: sha512-9pFH0SplrfyKyojCLxZfMcvkhf5hH0d+UwR9nTVJ/DDQJGuzcXjTwB7TP7sDfehSudlGGaOLblmEWqv04ERVWg==} + + lightningcss-android-arm64@1.32.0: + resolution: {integrity: sha512-YK7/ClTt4kAK0vo6w3X+Pnm0D2cf2vPHbhOXdoNti1Ga0al1P4TBZhwjATvjNwLEBCnKvjJc2jQgHXH0NEwlAg==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [android] + + lightningcss-darwin-arm64@1.32.0: + resolution: {integrity: sha512-RzeG9Ju5bag2Bv1/lwlVJvBE3q6TtXskdZLLCyfg5pt+HLz9BqlICO7LZM7VHNTTn/5PRhHFBSjk5lc4cmscPQ==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [darwin] + + lightningcss-darwin-x64@1.32.0: + resolution: {integrity: sha512-U+QsBp2m/s2wqpUYT/6wnlagdZbtZdndSmut/NJqlCcMLTWp5muCrID+K5UJ6jqD2BFshejCYXniPDbNh73V8w==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [darwin] + + lightningcss-freebsd-x64@1.32.0: + resolution: {integrity: sha512-JCTigedEksZk3tHTTthnMdVfGf61Fky8Ji2E4YjUTEQX14xiy/lTzXnu1vwiZe3bYe0q+SpsSH/CTeDXK6WHig==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [freebsd] + + lightningcss-linux-arm-gnueabihf@1.32.0: + resolution: {integrity: sha512-x6rnnpRa2GL0zQOkt6rts3YDPzduLpWvwAF6EMhXFVZXD4tPrBkEFqzGowzCsIWsPjqSK+tyNEODUBXeeVHSkw==} + engines: {node: '>= 12.0.0'} + cpu: [arm] + os: [linux] + + lightningcss-linux-arm64-gnu@1.32.0: + resolution: {integrity: sha512-0nnMyoyOLRJXfbMOilaSRcLH3Jw5z9HDNGfT/gwCPgaDjnx0i8w7vBzFLFR1f6CMLKF8gVbebmkUN3fa/kQJpQ==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [linux] + libc: [glibc] + + lightningcss-linux-arm64-musl@1.32.0: + resolution: {integrity: sha512-UpQkoenr4UJEzgVIYpI80lDFvRmPVg6oqboNHfoH4CQIfNA+HOrZ7Mo7KZP02dC6LjghPQJeBsvXhJod/wnIBg==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [linux] + libc: [musl] + + lightningcss-linux-x64-gnu@1.32.0: + resolution: {integrity: sha512-V7Qr52IhZmdKPVr+Vtw8o+WLsQJYCTd8loIfpDaMRWGUZfBOYEJeyJIkqGIDMZPwPx24pUMfwSxxI8phr/MbOA==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [linux] + libc: [glibc] + + lightningcss-linux-x64-musl@1.32.0: + resolution: {integrity: 
sha512-bYcLp+Vb0awsiXg/80uCRezCYHNg1/l3mt0gzHnWV9XP1W5sKa5/TCdGWaR/zBM2PeF/HbsQv/j2URNOiVuxWg==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [linux] + libc: [musl] + + lightningcss-win32-arm64-msvc@1.32.0: + resolution: {integrity: sha512-8SbC8BR40pS6baCM8sbtYDSwEVQd4JlFTOlaD3gWGHfThTcABnNDBda6eTZeqbofalIJhFx0qKzgHJmcPTnGdw==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [win32] + + lightningcss-win32-x64-msvc@1.32.0: + resolution: {integrity: sha512-Amq9B/SoZYdDi1kFrojnoqPLxYhQ4Wo5XiL8EVJrVsB8ARoC1PWW6VGtT0WKCemjy8aC+louJnjS7U18x3b06Q==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [win32] + + lightningcss@1.32.0: + resolution: {integrity: sha512-NXYBzinNrblfraPGyrbPoD19C1h9lfI/1mzgWYvXUTe414Gz/X1FD2XBZSZM7rRTrMA8JL3OtAaGifrIKhQ5yQ==} + engines: {node: '>= 12.0.0'} + + lodash.merge@4.6.2: + resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} + + lodash.mergewith@4.6.2: + resolution: {integrity: sha512-GK3g5RPZWTRSeLSpgP8Xhra+pnjBC56q9FZYe1d5RN3TJ35dbkGy3YqBSMbyCrlbi+CM9Z3Jk5yTL7RCsqboyQ==} + + lodash.sortby@4.7.0: + resolution: {integrity: sha512-HDWXG8isMntAyRF5vZ7xKuEvOhT4AhlRt/3czTSjvGUxjYCBVRQY48ViDHyfYz9VIoBkW4TMGQNapx+l3RUwdA==} + + lower-case@2.0.2: + resolution: {integrity: sha512-7fm3l3NAF9WfN6W3JOmf5drwpVqX78JtoGJ3A6W0a6ZnldM41w2fV5D490psKFTpMds8TJse/eHLFFsNHHjHgg==} + + lru-cache@11.2.7: + resolution: {integrity: sha512-aY/R+aEsRelme17KGQa/1ZSIpLpNYYrhcrepKTZgE+W3WM16YMCaPwOHLHsmopZHELU0Ojin1lPVxKR0MihncA==} + engines: {node: 20 || >=22} + + luxon@3.7.2: + resolution: {integrity: sha512-vtEhXh/gNjI9Yg1u4jX/0YVPMvxzHuGgCm6tC5kZyb08yjGWGnqAjGJvcXbqQR2P3MyMEFnRbpcdFS6PBcLqew==} + engines: {node: '>=12'} + + mime@3.0.0: + resolution: {integrity: sha512-jSCU7/VB1loIWBZe14aEYHU/+1UMEHoaO7qxCOVJOw9GgH72VAWppxNcjU+x9a2k3GSIBXNKxXQFqRvvZ7vr3A==} + engines: {node: '>=10.0.0'} + hasBin: true + + minimatch@10.2.4: + resolution: {integrity: sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==} + engines: {node: 18 || 20 || >=22} + + minipass@7.1.3: + resolution: {integrity: sha512-tEBHqDnIoM/1rXME1zgka9g6Q2lcoCkxHLuc7ODJ5BxbP5d4c2Z5cGgtXAku59200Cx7diuHTOYfSBD8n6mm8A==} + engines: {node: '>=16 || 14 >=14.17'} + + mkdirp@3.0.1: + resolution: {integrity: sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==} + engines: {node: '>=10'} + hasBin: true + + mrmime@2.0.1: + resolution: {integrity: sha512-Y3wQdFg2Va6etvQ5I82yUhGdsKrcYox6p7FfL1LbK2J4V01F9TGlepTIhnK24t7koZibmg82KGglhA1XK5IsLQ==} + engines: {node: '>=10'} + + ms@2.1.3: + resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} + + mz@2.7.0: + resolution: {integrity: sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==} + + nanoid@3.3.11: + resolution: {integrity: sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==} + engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} + hasBin: true + + no-case@3.0.4: + resolution: {integrity: sha512-fgAN3jGAh+RoxUGZHTSOLJIqUc2wmoBwGR4tbpNAKmmovFoWq0OdRkb0VkldReO2a2iBT/OEulG9XSUc10r3zg==} + + normalize-package-data@8.0.0: + resolution: {integrity: sha512-RWk+PI433eESQ7ounYxIp67CYuVsS1uYSonX3kA6ps/3LWfjVQa/ptEg6Y3T6uAMq1mWpX9PQ+qx+QaHpsc7gQ==} + engines: {node: ^20.17.0 || >=22.9.0} + + object-assign@4.1.1: + resolution: {integrity: 
sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==} + engines: {node: '>=0.10.0'} + + obug@2.1.1: + resolution: {integrity: sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==} + + oxfmt@0.42.0: + resolution: {integrity: sha512-QhejGErLSMReNuZ6vxgFHDyGoPbjTRNi6uGHjy0cvIjOQFqD6xmr/T+3L41ixR3NIgzcNiJ6ylQKpvShTgDfqg==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + + oxlint-tsgolint@0.17.3: + resolution: {integrity: sha512-1eh4bcpOMw0e7+YYVxmhFc2mo/V6hJ2+zfukqf+GprvVn3y94b69M/xNrYLmx5A+VdYe0i/bJ2xOs6Hp/jRmRA==} + hasBin: true + + oxlint@1.57.0: + resolution: {integrity: sha512-DGFsuBX5MFZX9yiDdtKjTrYPq45CZ8Fft6qCltJITYZxfwYjVdGf/6wycGYTACloauwIPxUnYhBVeZbHvleGhw==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + peerDependencies: + oxlint-tsgolint: '>=0.15.0' + peerDependenciesMeta: + oxlint-tsgolint: + optional: true + + pad-right@0.2.2: + resolution: {integrity: sha512-4cy8M95ioIGolCoMmm2cMntGR1lPLEbOMzOKu8bzjuJP6JpzEMQcDHmh7hHLYGgob+nKe1YHFMaG4V59HQa89g==} + engines: {node: '>=0.10.0'} + + parse-json@8.3.0: + resolution: {integrity: sha512-ybiGyvspI+fAoRQbIPRddCcSTV9/LsJbf0e/S85VLowVGzRmokfneg2kwVW/KU5rOXrPSbF1qAKPMgNTqqROQQ==} + engines: {node: '>=18'} + + path-key@3.1.1: + resolution: {integrity: sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==} + engines: {node: '>=8'} + + path-scurry@2.0.2: + resolution: {integrity: sha512-3O/iVVsJAPsOnpwWIeD+d6z/7PmqApyQePUtCndjatj/9I5LylHvt5qluFaBT3I5h3r1ejfR056c+FCv+NnNXg==} + engines: {node: 18 || 20 || >=22} + + picocolors@1.1.1: + resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} + + picomatch@4.0.4: + resolution: {integrity: sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==} + engines: {node: '>=12'} + + pixelmatch@7.1.0: + resolution: {integrity: sha512-1wrVzJ2STrpmONHKBy228LM1b84msXDUoAzVEl0R8Mz4Ce6EPr+IVtxm8+yvrqLYMHswREkjYFaMxnyGnaY3Ng==} + hasBin: true + + playwright-core@1.51.1: + resolution: {integrity: sha512-/crRMj8+j/Nq5s8QcvegseuyeZPxpQCZb6HNk3Sos3BlZyAknRjoyJPFWkpNn8v0+P3WiwqFF8P+zQo4eqiNuw==} + engines: {node: '>=18'} + hasBin: true + + playwright@1.51.1: + resolution: {integrity: sha512-kkx+MB2KQRkyxjYPc3a0wLZZoDczmppyGJIvQ43l+aZihkaVvmu/21kiyaHeHjiFxjxNNFnUncKmcGIyOojsaw==} + engines: {node: '>=18'} + hasBin: true + + pngjs@7.0.0: + resolution: {integrity: sha512-LKWqWJRhstyYo9pGvgor/ivk2w94eSjE3RGVuzLGlr3NmD8bf7RcYGze1mNdEHRP6TRP6rMuDHk5t44hnTRyow==} + engines: {node: '>=14.19.0'} + + postcss@8.5.8: + resolution: {integrity: sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==} + engines: {node: ^10 || ^12 || >=14} + + progress@2.0.3: + resolution: {integrity: sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==} + engines: {node: '>=0.4.0'} + + property-expr@2.0.6: + resolution: {integrity: sha512-SVtmxhRE/CGkn3eZY1T6pC8Nln6Fr/lu1mKSgRud0eC73whjGfoAogbn78LkD8aFL0zz3bAFerKSnOl7NlErBA==} + + read-package-up@12.0.0: + resolution: {integrity: sha512-Q5hMVBYur/eQNWDdbF4/Wqqr9Bjvtrw2kjGxxBbKLbx8bVCL8gcArjTy8zDUuLGQicftpMuU0riQNcAsbtOVsw==} + engines: {node: '>=20'} + + read-pkg@10.1.0: + resolution: {integrity: sha512-I8g2lArQiP78ll51UeMZojewtYgIRCKCWqZEgOO8c/uefTI+XDXvCSXu3+YNUaTNvZzobrL5+SqHjBrByRRTdg==} + engines: {node: '>=20'} + + reflect-metadata@0.2.2: + resolution: 
{integrity: sha512-urBwgfrvVP/eAyXx4hluJivBKzuEbSQs9rKWCrCkbSxNv8mxPcUZKeuoF3Uy4mJl3Lwprp6yy5/39VWigZ4K6Q==} + + regexp-match-indices@1.0.2: + resolution: {integrity: sha512-DwZuAkt8NF5mKwGGER1EGh2PRqyvhRhhLviH+R8y8dIuaQROlUfXjt4s9ZTXstIsSkptf06BSvwcEmmfheJJWQ==} + + regexp-tree@0.1.27: + resolution: {integrity: sha512-iETxpjK6YoRWJG5o6hXLwvjYAoW+FEZn9os0PD/b6AP6xQwsa/Y7lCVgIixBbUPMfhu+i2LtdeAqVTgGlQarfA==} + hasBin: true + + repeat-string@1.6.1: + resolution: {integrity: sha512-PV0dzCYDNfRi1jCDbJzpW7jNNDRuCOG/jI5ctQcGKt/clZD+YcPS3yIlWuTJMmESC8aevCFmWJy5wjAFgNqN6w==} + engines: {node: '>=0.10'} + + resolve-pkg-maps@1.0.0: + resolution: {integrity: sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==} + + rolldown@1.0.0-rc.12: + resolution: {integrity: sha512-yP4USLIMYrwpPHEFB5JGH1uxhcslv6/hL0OyvTuY+3qlOSJvZ7ntYnoWpehBxufkgN0cvXxppuTu5hHa/zPh+A==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + + seed-random@2.2.0: + resolution: {integrity: sha512-34EQV6AAHQGhoc0tn/96a9Fsi6v2xdqe/dMUwljGRaFOzR3EgRmECvD0O8vi8X+/uQ50LGHfkNu/Eue5TPKZkQ==} + + semver@7.7.4: + resolution: {integrity: sha512-vFKC2IEtQnVhpT78h1Yp8wzwrf8CM+MzKMHGJZfBtzhZNycRFnXsHk6E5TxIkkMsgNS7mdX3AGB7x2QM2di4lA==} + engines: {node: '>=10'} + hasBin: true + + shebang-command@2.0.0: + resolution: {integrity: sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==} + engines: {node: '>=8'} + + shebang-regex@3.0.0: + resolution: {integrity: sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==} + engines: {node: '>=8'} + + sirv@3.0.2: + resolution: {integrity: sha512-2wcC/oGxHis/BoHkkPwldgiPSYcpZK3JU28WoMVv55yHJgcZ8rlXvuG9iZggz+sU1d4bRgIGASwyWqjxu3FM0g==} + engines: {node: '>=18'} + + source-map-js@1.2.1: + resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==} + engines: {node: '>=0.10.0'} + + source-map-support@0.5.21: + resolution: {integrity: sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==} + + source-map@0.6.1: + resolution: {integrity: sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==} + engines: {node: '>=0.10.0'} + + spdx-correct@3.2.0: + resolution: {integrity: sha512-kN9dJbvnySHULIluDHy32WHRUu3Og7B9sbY7tsFLctQkIqnMh3hErYgdMjTYuqmcXX+lK5T1lnUt3G7zNswmZA==} + + spdx-exceptions@2.5.0: + resolution: {integrity: sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==} + + spdx-expression-parse@3.0.1: + resolution: {integrity: sha512-cbqHunsQWnJNE6KhVSMsMeH5H/L9EpymbzqTQ3uLwNCLZ1Q481oWaofqH7nO6V07xlXwY6PhQdQ2IedWx/ZK4Q==} + + spdx-license-ids@3.0.23: + resolution: {integrity: sha512-CWLcCCH7VLu13TgOH+r8p1O/Znwhqv/dbb6lqWy67G+pT1kHmeD/+V36AVb/vq8QMIQwVShJ6Ssl5FPh0fuSdw==} + + stackframe@1.3.4: + resolution: {integrity: sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==} + + std-env@4.0.0: + resolution: {integrity: sha512-zUMPtQ/HBY3/50VbpkupYHbRroTRZJPRLvreamgErJVys0ceuzMkD44J/QjqhHjOzK42GQ3QZIeFG1OYfOtKqQ==} + + string-argv@0.3.1: + resolution: {integrity: sha512-a1uQGz7IyVy9YwhqjZIZu1c8JO8dNIe20xBmSS6qu9kv++k3JGzCVmprbNN5Kn+BgzD5E7YYwg1CcjuJMRNsvg==} + engines: {node: '>=0.6.19'} + + string-width@4.2.3: + resolution: {integrity: sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==} + engines: {node: '>=8'} + + 
strip-ansi@6.0.1: + resolution: {integrity: sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==} + engines: {node: '>=8'} + + supports-color@7.2.0: + resolution: {integrity: sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==} + engines: {node: '>=8'} + + supports-color@8.1.1: + resolution: {integrity: sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==} + engines: {node: '>=10'} + + tagged-tag@1.0.0: + resolution: {integrity: sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng==} + engines: {node: '>=20'} + + thenify-all@1.6.0: + resolution: {integrity: sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==} + engines: {node: '>=0.8'} + + thenify@3.3.1: + resolution: {integrity: sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==} + + tiny-case@1.0.3: + resolution: {integrity: sha512-Eet/eeMhkO6TX8mnUteS9zgPbUMQa4I6Kkp5ORiBD5476/m+PIRiumP5tmh5ioJpH7k51Kehawy2UDfsnxxY8Q==} + + tinybench@2.9.0: + resolution: {integrity: sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==} + + tinyexec@1.0.4: + resolution: {integrity: sha512-u9r3uZC0bdpGOXtlxUIdwf9pkmvhqJdrVCH9fapQtgy/OeTTMZ1nqH7agtvEfmGui6e1XxjcdrlxvxJvc3sMqw==} + engines: {node: '>=18'} + + tinyglobby@0.2.15: + resolution: {integrity: sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ==} + engines: {node: '>=12.0.0'} + + tinypool@2.1.0: + resolution: {integrity: sha512-Pugqs6M0m7Lv1I7FtxN4aoyToKg1C4tu+/381vH35y8oENM/Ai7f7C4StcoK4/+BSw9ebcS8jRiVrORFKCALLw==} + engines: {node: ^20.0.0 || >=22.0.0} + + toposort@2.0.2: + resolution: {integrity: sha512-0a5EOkAUp8D4moMi2W8ZF8jcga7BgZd91O/yabJCFY8az+XSzeGyTKs0Aoo897iV1Nj6guFq8orWDS96z91oGg==} + + totalist@3.0.1: + resolution: {integrity: sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==} + engines: {node: '>=6'} + + ts-dedent@2.2.0: + resolution: {integrity: sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==} + engines: {node: '>=6.10'} + + tslib@2.8.1: + resolution: {integrity: sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==} + + tsx@4.21.0: + resolution: {integrity: sha512-5C1sg4USs1lfG0GFb2RLXsdpXqBSEhAaA/0kPL01wxzpMqLILNxIxIOKiILz+cdg/pLnOUxFYOR5yhHU666wbw==} + engines: {node: '>=18.0.0'} + hasBin: true + + type-fest@2.19.0: + resolution: {integrity: sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==} + engines: {node: '>=12.20'} + + type-fest@4.41.0: + resolution: {integrity: sha512-TeTSQ6H5YHvpqVwBRcnLDCBnDOHWYu7IvGbHT6N8AOymcr9PJGjc1GTtiWZTYg0NCgYwvnYWEkVChQAr9bjfwA==} + engines: {node: '>=16'} + + type-fest@5.5.0: + resolution: {integrity: sha512-PlBfpQwiUvGViBNX84Yxwjsdhd1TUlXr6zjX7eoirtCPIr08NAmxwa+fcYBTeRQxHo9YC9wwF3m9i700sHma8g==} + engines: {node: '>=20'} + + typescript@5.9.3: + resolution: {integrity: sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==} + engines: {node: '>=14.17'} + hasBin: true + + undici-types@7.18.2: + resolution: {integrity: sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==} + + unicorn-magic@0.4.0: + resolution: {integrity: 
sha512-wH590V9VNgYH9g3lH9wWjTrUoKsjLF6sGLjhR4sH1LWpLmCOH0Zf7PukhDA8BiS7KHe4oPNkcTHqYkj7SOGUOw==} + engines: {node: '>=20'} + + upper-case-first@2.0.2: + resolution: {integrity: sha512-514ppYHBaKwfJRK/pNC6c/OxfGa0obSnAl106u97Ed0I625Nin96KAjttZF6ZL3e1XLtphxnqrOi9iWgm+u+bg==} + + util-arity@1.1.0: + resolution: {integrity: sha512-kkyIsXKwemfSy8ZEoaIz06ApApnWsk5hQO0vLjZS6UkBiGiW++Jsyb8vSBoc0WKlffGoGs5yYy/j5pp8zckrFA==} + + validate-npm-package-license@3.0.4: + resolution: {integrity: sha512-DpKm2Ui/xN7/HQKCtpZxoRWBhZ9Z0kqtygG8XCgNQ8ZlDnxuQmWhj566j8fN4Cu3/JmbhsDo7fcAJq4s9h27Ew==} + + vite-plus@0.1.14: + resolution: {integrity: sha512-p4pWlpZZNiEsHxPWNdeIU9iuPix3ydm3ficb0dXPggoyIkdotfXtvn2NPX9KwfiQImU72EVEs4+VYBZYNcUYrw==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + + vite@8.0.3: + resolution: {integrity: sha512-B9ifbFudT1TFhfltfaIPgjo9Z3mDynBTJSUYxTjOQruf/zHH+ezCQKcoqO+h7a9Pw9Nm/OtlXAiGT1axBgwqrQ==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + peerDependencies: + '@types/node': ^20.19.0 || >=22.12.0 + '@vitejs/devtools': ^0.1.0 + esbuild: ^0.27.0 + jiti: '>=1.21.0' + less: ^4.0.0 + sass: ^1.70.0 + sass-embedded: ^1.70.0 + stylus: '>=0.54.8' + sugarss: ^5.0.0 + terser: ^5.16.0 + tsx: ^4.8.1 + yaml: ^2.4.2 + peerDependenciesMeta: + '@types/node': + optional: true + '@vitejs/devtools': + optional: true + esbuild: + optional: true + jiti: + optional: true + less: + optional: true + sass: + optional: true + sass-embedded: + optional: true + stylus: + optional: true + sugarss: + optional: true + terser: + optional: true + tsx: + optional: true + yaml: + optional: true + + which@2.0.2: + resolution: {integrity: sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==} + engines: {node: '>= 8'} + hasBin: true + + ws@8.20.0: + resolution: {integrity: sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + + xmlbuilder@15.1.1: + resolution: {integrity: sha512-yMqGBqtXyeN1e3TGYvgNgDVZ3j84W4cwkOXQswghol6APgZWaff9lnbvN7MHYJOiXsvGPXtjTYJEiC9J2wv9Eg==} + engines: {node: '>=8.0'} + + yaml@2.8.3: + resolution: {integrity: sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg==} + engines: {node: '>= 14.6'} + hasBin: true + + yup@1.7.1: + resolution: {integrity: sha512-GKHFX2nXul2/4Dtfxhozv701jLQHdf6J34YDh2cEkpqoo8le5Mg6/LrdseVLrFarmFygZTlfIhHx/QKfb/QWXw==} + +snapshots: + + '@babel/code-frame@7.29.0': + dependencies: + '@babel/helper-validator-identifier': 7.28.5 + js-tokens: 4.0.0 + picocolors: 1.1.1 + + '@babel/helper-validator-identifier@7.28.5': {} + + '@colors/colors@1.5.0': + optional: true + + '@cucumber/ci-environment@13.0.0': {} + + '@cucumber/cucumber-expressions@19.0.0': + dependencies: + regexp-match-indices: 1.0.2 + + '@cucumber/cucumber@12.7.0': + dependencies: + '@cucumber/ci-environment': 13.0.0 + '@cucumber/cucumber-expressions': 19.0.0 + '@cucumber/gherkin': 38.0.0 + '@cucumber/gherkin-streams': 6.0.0(@cucumber/gherkin@38.0.0)(@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1))(@cucumber/messages@32.0.1) + '@cucumber/gherkin-utils': 11.0.0 + '@cucumber/html-formatter': 23.0.0(@cucumber/messages@32.0.1) + '@cucumber/junit-xml-formatter': 0.9.0(@cucumber/messages@32.0.1) + '@cucumber/message-streams': 4.0.1(@cucumber/messages@32.0.1) + 
'@cucumber/messages': 32.0.1 + '@cucumber/pretty-formatter': 1.0.1(@cucumber/cucumber@12.7.0)(@cucumber/messages@32.0.1) + '@cucumber/tag-expressions': 9.1.0 + assertion-error-formatter: 3.0.0 + capital-case: 1.0.4 + chalk: 4.1.2 + cli-table3: 0.6.5 + commander: 14.0.3 + debug: 4.4.3(supports-color@8.1.1) + error-stack-parser: 2.1.4 + figures: 3.2.0 + glob: 13.0.6 + has-ansi: 4.0.1 + indent-string: 4.0.0 + is-installed-globally: 0.4.0 + is-stream: 2.0.1 + knuth-shuffle-seeded: 1.0.6 + lodash.merge: 4.6.2 + lodash.mergewith: 4.6.2 + luxon: 3.7.2 + mime: 3.0.0 + mkdirp: 3.0.1 + mz: 2.7.0 + progress: 2.0.3 + read-package-up: 12.0.0 + semver: 7.7.4 + string-argv: 0.3.1 + supports-color: 8.1.1 + type-fest: 4.41.0 + util-arity: 1.1.0 + yaml: 2.8.3 + yup: 1.7.1 + + '@cucumber/gherkin-streams@6.0.0(@cucumber/gherkin@38.0.0)(@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1))(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/gherkin': 38.0.0 + '@cucumber/message-streams': 4.0.1(@cucumber/messages@32.0.1) + '@cucumber/messages': 32.0.1 + commander: 14.0.0 + source-map-support: 0.5.21 + + '@cucumber/gherkin-utils@11.0.0': + dependencies: + '@cucumber/gherkin': 38.0.0 + '@cucumber/messages': 32.0.1 + '@teppeis/multimaps': 3.0.0 + commander: 14.0.2 + source-map-support: 0.5.21 + + '@cucumber/gherkin@38.0.0': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/html-formatter@23.0.0(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/junit-xml-formatter@0.9.0(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + '@cucumber/query': 14.7.0(@cucumber/messages@32.0.1) + '@teppeis/multimaps': 3.0.0 + luxon: 3.7.2 + xmlbuilder: 15.1.1 + + '@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/messages@32.0.1': + dependencies: + class-transformer: 0.5.1 + reflect-metadata: 0.2.2 + + '@cucumber/pretty-formatter@1.0.1(@cucumber/cucumber@12.7.0)(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/cucumber': 12.7.0 + '@cucumber/messages': 32.0.1 + ansi-styles: 5.2.0 + cli-table3: 0.6.5 + figures: 3.2.0 + ts-dedent: 2.2.0 + + '@cucumber/query@14.7.0(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + '@teppeis/multimaps': 3.0.0 + lodash.sortby: 4.7.0 + + '@cucumber/tag-expressions@9.1.0': {} + + '@emnapi/core@1.9.1': + dependencies: + '@emnapi/wasi-threads': 1.2.0 + tslib: 2.8.1 + optional: true + + '@emnapi/runtime@1.9.1': + dependencies: + tslib: 2.8.1 + optional: true + + '@emnapi/wasi-threads@1.2.0': + dependencies: + tslib: 2.8.1 + optional: true + + '@esbuild/aix-ppc64@0.27.4': + optional: true + + '@esbuild/android-arm64@0.27.4': + optional: true + + '@esbuild/android-arm@0.27.4': + optional: true + + '@esbuild/android-x64@0.27.4': + optional: true + + '@esbuild/darwin-arm64@0.27.4': + optional: true + + '@esbuild/darwin-x64@0.27.4': + optional: true + + '@esbuild/freebsd-arm64@0.27.4': + optional: true + + '@esbuild/freebsd-x64@0.27.4': + optional: true + + '@esbuild/linux-arm64@0.27.4': + optional: true + + '@esbuild/linux-arm@0.27.4': + optional: true + + '@esbuild/linux-ia32@0.27.4': + optional: true + + '@esbuild/linux-loong64@0.27.4': + optional: true + + '@esbuild/linux-mips64el@0.27.4': + optional: true + + '@esbuild/linux-ppc64@0.27.4': + optional: true + + '@esbuild/linux-riscv64@0.27.4': + optional: true + + '@esbuild/linux-s390x@0.27.4': + optional: true + + '@esbuild/linux-x64@0.27.4': + optional: true + + 
'@esbuild/netbsd-arm64@0.27.4': + optional: true + + '@esbuild/netbsd-x64@0.27.4': + optional: true + + '@esbuild/openbsd-arm64@0.27.4': + optional: true + + '@esbuild/openbsd-x64@0.27.4': + optional: true + + '@esbuild/openharmony-arm64@0.27.4': + optional: true + + '@esbuild/sunos-x64@0.27.4': + optional: true + + '@esbuild/win32-arm64@0.27.4': + optional: true + + '@esbuild/win32-ia32@0.27.4': + optional: true + + '@esbuild/win32-x64@0.27.4': + optional: true + + '@napi-rs/wasm-runtime@1.1.1': + dependencies: + '@emnapi/core': 1.9.1 + '@emnapi/runtime': 1.9.1 + '@tybys/wasm-util': 0.10.1 + optional: true + + '@oxc-project/runtime@0.121.0': {} + + '@oxc-project/types@0.122.0': {} + + '@oxfmt/binding-android-arm-eabi@0.42.0': + optional: true + + '@oxfmt/binding-android-arm64@0.42.0': + optional: true + + '@oxfmt/binding-darwin-arm64@0.42.0': + optional: true + + '@oxfmt/binding-darwin-x64@0.42.0': + optional: true + + '@oxfmt/binding-freebsd-x64@0.42.0': + optional: true + + '@oxfmt/binding-linux-arm-gnueabihf@0.42.0': + optional: true + + '@oxfmt/binding-linux-arm-musleabihf@0.42.0': + optional: true + + '@oxfmt/binding-linux-arm64-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-arm64-musl@0.42.0': + optional: true + + '@oxfmt/binding-linux-ppc64-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-riscv64-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-riscv64-musl@0.42.0': + optional: true + + '@oxfmt/binding-linux-s390x-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-x64-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-x64-musl@0.42.0': + optional: true + + '@oxfmt/binding-openharmony-arm64@0.42.0': + optional: true + + '@oxfmt/binding-win32-arm64-msvc@0.42.0': + optional: true + + '@oxfmt/binding-win32-ia32-msvc@0.42.0': + optional: true + + '@oxfmt/binding-win32-x64-msvc@0.42.0': + optional: true + + '@oxlint-tsgolint/darwin-arm64@0.17.3': + optional: true + + '@oxlint-tsgolint/darwin-x64@0.17.3': + optional: true + + '@oxlint-tsgolint/linux-arm64@0.17.3': + optional: true + + '@oxlint-tsgolint/linux-x64@0.17.3': + optional: true + + '@oxlint-tsgolint/win32-arm64@0.17.3': + optional: true + + '@oxlint-tsgolint/win32-x64@0.17.3': + optional: true + + '@oxlint/binding-android-arm-eabi@1.57.0': + optional: true + + '@oxlint/binding-android-arm64@1.57.0': + optional: true + + '@oxlint/binding-darwin-arm64@1.57.0': + optional: true + + '@oxlint/binding-darwin-x64@1.57.0': + optional: true + + '@oxlint/binding-freebsd-x64@1.57.0': + optional: true + + '@oxlint/binding-linux-arm-gnueabihf@1.57.0': + optional: true + + '@oxlint/binding-linux-arm-musleabihf@1.57.0': + optional: true + + '@oxlint/binding-linux-arm64-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-arm64-musl@1.57.0': + optional: true + + '@oxlint/binding-linux-ppc64-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-riscv64-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-riscv64-musl@1.57.0': + optional: true + + '@oxlint/binding-linux-s390x-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-x64-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-x64-musl@1.57.0': + optional: true + + '@oxlint/binding-openharmony-arm64@1.57.0': + optional: true + + '@oxlint/binding-win32-arm64-msvc@1.57.0': + optional: true + + '@oxlint/binding-win32-ia32-msvc@1.57.0': + optional: true + + '@oxlint/binding-win32-x64-msvc@1.57.0': + optional: true + + '@playwright/test@1.51.1': + dependencies: + playwright: 1.51.1 + + '@polka/url@1.0.0-next.29': {} + + 
'@rolldown/binding-android-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-darwin-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-darwin-x64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-freebsd-x64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm64-musl@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-s390x-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-x64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-x64-musl@1.0.0-rc.12': + optional: true + + '@rolldown/binding-openharmony-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-wasm32-wasi@1.0.0-rc.12': + dependencies: + '@napi-rs/wasm-runtime': 1.1.1 + optional: true + + '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.12': + optional: true + + '@rolldown/binding-win32-x64-msvc@1.0.0-rc.12': + optional: true + + '@rolldown/pluginutils@1.0.0-rc.12': {} + + '@standard-schema/spec@1.1.0': {} + + '@teppeis/multimaps@3.0.0': {} + + '@tybys/wasm-util@0.10.1': + dependencies: + tslib: 2.8.1 + optional: true + + '@types/chai@5.2.3': + dependencies: + '@types/deep-eql': 4.0.2 + assertion-error: 2.0.1 + + '@types/deep-eql@4.0.2': {} + + '@types/node@25.5.0': + dependencies: + undici-types: 7.18.2 + + '@types/normalize-package-data@2.4.4': {} + + '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': + dependencies: + '@oxc-project/runtime': 0.121.0 + '@oxc-project/types': 0.122.0 + lightningcss: 1.32.0 + postcss: 8.5.8 + optionalDependencies: + '@types/node': 25.5.0 + esbuild: 0.27.4 + fsevents: 2.3.3 + tsx: 4.21.0 + typescript: 5.9.3 + yaml: 2.8.3 + + '@voidzero-dev/vite-plus-darwin-arm64@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-darwin-x64@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-arm64-gnu@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-arm64-musl@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-x64-musl@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)': + dependencies: + '@standard-schema/spec': 1.1.0 + '@types/chai': 5.2.3 + '@voidzero-dev/vite-plus-core': 0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + es-module-lexer: 1.7.0 + obug: 2.1.1 + pixelmatch: 7.1.0 + pngjs: 7.0.0 + sirv: 3.0.2 + std-env: 4.0.0 + tinybench: 2.9.0 + tinyexec: 1.0.4 + tinyglobby: 0.2.15 + vite: 8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3) + ws: 8.20.0 + optionalDependencies: + '@types/node': 25.5.0 + transitivePeerDependencies: + - '@arethetypeswrong/core' + - '@tsdown/css' + - '@tsdown/exe' + - '@vitejs/devtools' + - bufferutil + - esbuild + - jiti + - less + - publint + - sass + - sass-embedded + - stylus + - sugarss + - terser + - tsx + - typescript + - unplugin-unused + - utf-8-validate + - yaml + + '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-win32-x64-msvc@0.1.14': + optional: true + + ansi-regex@4.1.1: {} + + ansi-regex@5.0.1: {} + + ansi-styles@4.3.0: + dependencies: + color-convert: 2.0.1 + + 
ansi-styles@5.2.0: {} + + any-promise@1.3.0: {} + + assertion-error-formatter@3.0.0: + dependencies: + diff: 4.0.4 + pad-right: 0.2.2 + repeat-string: 1.6.1 + + assertion-error@2.0.1: {} + + balanced-match@4.0.4: {} + + brace-expansion@5.0.5: + dependencies: + balanced-match: 4.0.4 + + buffer-from@1.1.2: {} + + cac@7.0.0: {} + + capital-case@1.0.4: + dependencies: + no-case: 3.0.4 + tslib: 2.8.1 + upper-case-first: 2.0.2 + + chalk@4.1.2: + dependencies: + ansi-styles: 4.3.0 + supports-color: 7.2.0 + + class-transformer@0.5.1: {} + + cli-table3@0.6.5: + dependencies: + string-width: 4.2.3 + optionalDependencies: + '@colors/colors': 1.5.0 + + color-convert@2.0.1: + dependencies: + color-name: 1.1.4 + + color-name@1.1.4: {} + + commander@14.0.0: {} + + commander@14.0.2: {} + + commander@14.0.3: {} + + cross-spawn@7.0.6: + dependencies: + path-key: 3.1.1 + shebang-command: 2.0.0 + which: 2.0.2 + + debug@4.4.3(supports-color@8.1.1): + dependencies: + ms: 2.1.3 + optionalDependencies: + supports-color: 8.1.1 + + detect-libc@2.1.2: {} + + diff@4.0.4: {} + + emoji-regex@8.0.0: {} + + error-stack-parser@2.1.4: + dependencies: + stackframe: 1.3.4 + + es-module-lexer@1.7.0: {} + + esbuild@0.27.4: + optionalDependencies: + '@esbuild/aix-ppc64': 0.27.4 + '@esbuild/android-arm': 0.27.4 + '@esbuild/android-arm64': 0.27.4 + '@esbuild/android-x64': 0.27.4 + '@esbuild/darwin-arm64': 0.27.4 + '@esbuild/darwin-x64': 0.27.4 + '@esbuild/freebsd-arm64': 0.27.4 + '@esbuild/freebsd-x64': 0.27.4 + '@esbuild/linux-arm': 0.27.4 + '@esbuild/linux-arm64': 0.27.4 + '@esbuild/linux-ia32': 0.27.4 + '@esbuild/linux-loong64': 0.27.4 + '@esbuild/linux-mips64el': 0.27.4 + '@esbuild/linux-ppc64': 0.27.4 + '@esbuild/linux-riscv64': 0.27.4 + '@esbuild/linux-s390x': 0.27.4 + '@esbuild/linux-x64': 0.27.4 + '@esbuild/netbsd-arm64': 0.27.4 + '@esbuild/netbsd-x64': 0.27.4 + '@esbuild/openbsd-arm64': 0.27.4 + '@esbuild/openbsd-x64': 0.27.4 + '@esbuild/openharmony-arm64': 0.27.4 + '@esbuild/sunos-x64': 0.27.4 + '@esbuild/win32-arm64': 0.27.4 + '@esbuild/win32-ia32': 0.27.4 + '@esbuild/win32-x64': 0.27.4 + + escape-string-regexp@1.0.5: {} + + fdir@6.5.0(picomatch@4.0.4): + optionalDependencies: + picomatch: 4.0.4 + + figures@3.2.0: + dependencies: + escape-string-regexp: 1.0.5 + + find-up-simple@1.0.1: {} + + fsevents@2.3.2: + optional: true + + fsevents@2.3.3: + optional: true + + get-tsconfig@4.13.7: + dependencies: + resolve-pkg-maps: 1.0.0 + + glob@13.0.6: + dependencies: + minimatch: 10.2.4 + minipass: 7.1.3 + path-scurry: 2.0.2 + + global-dirs@3.0.1: + dependencies: + ini: 2.0.0 + + has-ansi@4.0.1: + dependencies: + ansi-regex: 4.1.1 + + has-flag@4.0.0: {} + + hosted-git-info@9.0.2: + dependencies: + lru-cache: 11.2.7 + + indent-string@4.0.0: {} + + index-to-position@1.2.0: {} + + ini@2.0.0: {} + + is-fullwidth-code-point@3.0.0: {} + + is-installed-globally@0.4.0: + dependencies: + global-dirs: 3.0.1 + is-path-inside: 3.0.3 + + is-path-inside@3.0.3: {} + + is-stream@2.0.1: {} + + isexe@2.0.0: {} + + js-tokens@4.0.0: {} + + knuth-shuffle-seeded@1.0.6: + dependencies: + seed-random: 2.2.0 + + lightningcss-android-arm64@1.32.0: + optional: true + + lightningcss-darwin-arm64@1.32.0: + optional: true + + lightningcss-darwin-x64@1.32.0: + optional: true + + lightningcss-freebsd-x64@1.32.0: + optional: true + + lightningcss-linux-arm-gnueabihf@1.32.0: + optional: true + + lightningcss-linux-arm64-gnu@1.32.0: + optional: true + + lightningcss-linux-arm64-musl@1.32.0: + optional: true + + lightningcss-linux-x64-gnu@1.32.0: + optional: true 
+ + lightningcss-linux-x64-musl@1.32.0: + optional: true + + lightningcss-win32-arm64-msvc@1.32.0: + optional: true + + lightningcss-win32-x64-msvc@1.32.0: + optional: true + + lightningcss@1.32.0: + dependencies: + detect-libc: 2.1.2 + optionalDependencies: + lightningcss-android-arm64: 1.32.0 + lightningcss-darwin-arm64: 1.32.0 + lightningcss-darwin-x64: 1.32.0 + lightningcss-freebsd-x64: 1.32.0 + lightningcss-linux-arm-gnueabihf: 1.32.0 + lightningcss-linux-arm64-gnu: 1.32.0 + lightningcss-linux-arm64-musl: 1.32.0 + lightningcss-linux-x64-gnu: 1.32.0 + lightningcss-linux-x64-musl: 1.32.0 + lightningcss-win32-arm64-msvc: 1.32.0 + lightningcss-win32-x64-msvc: 1.32.0 + + lodash.merge@4.6.2: {} + + lodash.mergewith@4.6.2: {} + + lodash.sortby@4.7.0: {} + + lower-case@2.0.2: + dependencies: + tslib: 2.8.1 + + lru-cache@11.2.7: {} + + luxon@3.7.2: {} + + mime@3.0.0: {} + + minimatch@10.2.4: + dependencies: + brace-expansion: 5.0.5 + + minipass@7.1.3: {} + + mkdirp@3.0.1: {} + + mrmime@2.0.1: {} + + ms@2.1.3: {} + + mz@2.7.0: + dependencies: + any-promise: 1.3.0 + object-assign: 4.1.1 + thenify-all: 1.6.0 + + nanoid@3.3.11: {} + + no-case@3.0.4: + dependencies: + lower-case: 2.0.2 + tslib: 2.8.1 + + normalize-package-data@8.0.0: + dependencies: + hosted-git-info: 9.0.2 + semver: 7.7.4 + validate-npm-package-license: 3.0.4 + + object-assign@4.1.1: {} + + obug@2.1.1: {} + + oxfmt@0.42.0: + dependencies: + tinypool: 2.1.0 + optionalDependencies: + '@oxfmt/binding-android-arm-eabi': 0.42.0 + '@oxfmt/binding-android-arm64': 0.42.0 + '@oxfmt/binding-darwin-arm64': 0.42.0 + '@oxfmt/binding-darwin-x64': 0.42.0 + '@oxfmt/binding-freebsd-x64': 0.42.0 + '@oxfmt/binding-linux-arm-gnueabihf': 0.42.0 + '@oxfmt/binding-linux-arm-musleabihf': 0.42.0 + '@oxfmt/binding-linux-arm64-gnu': 0.42.0 + '@oxfmt/binding-linux-arm64-musl': 0.42.0 + '@oxfmt/binding-linux-ppc64-gnu': 0.42.0 + '@oxfmt/binding-linux-riscv64-gnu': 0.42.0 + '@oxfmt/binding-linux-riscv64-musl': 0.42.0 + '@oxfmt/binding-linux-s390x-gnu': 0.42.0 + '@oxfmt/binding-linux-x64-gnu': 0.42.0 + '@oxfmt/binding-linux-x64-musl': 0.42.0 + '@oxfmt/binding-openharmony-arm64': 0.42.0 + '@oxfmt/binding-win32-arm64-msvc': 0.42.0 + '@oxfmt/binding-win32-ia32-msvc': 0.42.0 + '@oxfmt/binding-win32-x64-msvc': 0.42.0 + + oxlint-tsgolint@0.17.3: + optionalDependencies: + '@oxlint-tsgolint/darwin-arm64': 0.17.3 + '@oxlint-tsgolint/darwin-x64': 0.17.3 + '@oxlint-tsgolint/linux-arm64': 0.17.3 + '@oxlint-tsgolint/linux-x64': 0.17.3 + '@oxlint-tsgolint/win32-arm64': 0.17.3 + '@oxlint-tsgolint/win32-x64': 0.17.3 + + oxlint@1.57.0(oxlint-tsgolint@0.17.3): + optionalDependencies: + '@oxlint/binding-android-arm-eabi': 1.57.0 + '@oxlint/binding-android-arm64': 1.57.0 + '@oxlint/binding-darwin-arm64': 1.57.0 + '@oxlint/binding-darwin-x64': 1.57.0 + '@oxlint/binding-freebsd-x64': 1.57.0 + '@oxlint/binding-linux-arm-gnueabihf': 1.57.0 + '@oxlint/binding-linux-arm-musleabihf': 1.57.0 + '@oxlint/binding-linux-arm64-gnu': 1.57.0 + '@oxlint/binding-linux-arm64-musl': 1.57.0 + '@oxlint/binding-linux-ppc64-gnu': 1.57.0 + '@oxlint/binding-linux-riscv64-gnu': 1.57.0 + '@oxlint/binding-linux-riscv64-musl': 1.57.0 + '@oxlint/binding-linux-s390x-gnu': 1.57.0 + '@oxlint/binding-linux-x64-gnu': 1.57.0 + '@oxlint/binding-linux-x64-musl': 1.57.0 + '@oxlint/binding-openharmony-arm64': 1.57.0 + '@oxlint/binding-win32-arm64-msvc': 1.57.0 + '@oxlint/binding-win32-ia32-msvc': 1.57.0 + '@oxlint/binding-win32-x64-msvc': 1.57.0 + oxlint-tsgolint: 0.17.3 + + pad-right@0.2.2: + dependencies: + 
repeat-string: 1.6.1 + + parse-json@8.3.0: + dependencies: + '@babel/code-frame': 7.29.0 + index-to-position: 1.2.0 + type-fest: 4.41.0 + + path-key@3.1.1: {} + + path-scurry@2.0.2: + dependencies: + lru-cache: 11.2.7 + minipass: 7.1.3 + + picocolors@1.1.1: {} + + picomatch@4.0.4: {} + + pixelmatch@7.1.0: + dependencies: + pngjs: 7.0.0 + + playwright-core@1.51.1: {} + + playwright@1.51.1: + dependencies: + playwright-core: 1.51.1 + optionalDependencies: + fsevents: 2.3.2 + + pngjs@7.0.0: {} + + postcss@8.5.8: + dependencies: + nanoid: 3.3.11 + picocolors: 1.1.1 + source-map-js: 1.2.1 + + progress@2.0.3: {} + + property-expr@2.0.6: {} + + read-package-up@12.0.0: + dependencies: + find-up-simple: 1.0.1 + read-pkg: 10.1.0 + type-fest: 5.5.0 + + read-pkg@10.1.0: + dependencies: + '@types/normalize-package-data': 2.4.4 + normalize-package-data: 8.0.0 + parse-json: 8.3.0 + type-fest: 5.5.0 + unicorn-magic: 0.4.0 + + reflect-metadata@0.2.2: {} + + regexp-match-indices@1.0.2: + dependencies: + regexp-tree: 0.1.27 + + regexp-tree@0.1.27: {} + + repeat-string@1.6.1: {} + + resolve-pkg-maps@1.0.0: {} + + rolldown@1.0.0-rc.12: + dependencies: + '@oxc-project/types': 0.122.0 + '@rolldown/pluginutils': 1.0.0-rc.12 + optionalDependencies: + '@rolldown/binding-android-arm64': 1.0.0-rc.12 + '@rolldown/binding-darwin-arm64': 1.0.0-rc.12 + '@rolldown/binding-darwin-x64': 1.0.0-rc.12 + '@rolldown/binding-freebsd-x64': 1.0.0-rc.12 + '@rolldown/binding-linux-arm-gnueabihf': 1.0.0-rc.12 + '@rolldown/binding-linux-arm64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-arm64-musl': 1.0.0-rc.12 + '@rolldown/binding-linux-ppc64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-s390x-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-x64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-x64-musl': 1.0.0-rc.12 + '@rolldown/binding-openharmony-arm64': 1.0.0-rc.12 + '@rolldown/binding-wasm32-wasi': 1.0.0-rc.12 + '@rolldown/binding-win32-arm64-msvc': 1.0.0-rc.12 + '@rolldown/binding-win32-x64-msvc': 1.0.0-rc.12 + + seed-random@2.2.0: {} + + semver@7.7.4: {} + + shebang-command@2.0.0: + dependencies: + shebang-regex: 3.0.0 + + shebang-regex@3.0.0: {} + + sirv@3.0.2: + dependencies: + '@polka/url': 1.0.0-next.29 + mrmime: 2.0.1 + totalist: 3.0.1 + + source-map-js@1.2.1: {} + + source-map-support@0.5.21: + dependencies: + buffer-from: 1.1.2 + source-map: 0.6.1 + + source-map@0.6.1: {} + + spdx-correct@3.2.0: + dependencies: + spdx-expression-parse: 3.0.1 + spdx-license-ids: 3.0.23 + + spdx-exceptions@2.5.0: {} + + spdx-expression-parse@3.0.1: + dependencies: + spdx-exceptions: 2.5.0 + spdx-license-ids: 3.0.23 + + spdx-license-ids@3.0.23: {} + + stackframe@1.3.4: {} + + std-env@4.0.0: {} + + string-argv@0.3.1: {} + + string-width@4.2.3: + dependencies: + emoji-regex: 8.0.0 + is-fullwidth-code-point: 3.0.0 + strip-ansi: 6.0.1 + + strip-ansi@6.0.1: + dependencies: + ansi-regex: 5.0.1 + + supports-color@7.2.0: + dependencies: + has-flag: 4.0.0 + + supports-color@8.1.1: + dependencies: + has-flag: 4.0.0 + + tagged-tag@1.0.0: {} + + thenify-all@1.6.0: + dependencies: + thenify: 3.3.1 + + thenify@3.3.1: + dependencies: + any-promise: 1.3.0 + + tiny-case@1.0.3: {} + + tinybench@2.9.0: {} + + tinyexec@1.0.4: {} + + tinyglobby@0.2.15: + dependencies: + fdir: 6.5.0(picomatch@4.0.4) + picomatch: 4.0.4 + + tinypool@2.1.0: {} + + toposort@2.0.2: {} + + totalist@3.0.1: {} + + ts-dedent@2.2.0: {} + + tslib@2.8.1: {} + + tsx@4.21.0: + dependencies: + esbuild: 0.27.4 + get-tsconfig: 4.13.7 + optionalDependencies: + fsevents: 2.3.3 + + type-fest@2.19.0: {} + + 
type-fest@4.41.0: {} + + type-fest@5.5.0: + dependencies: + tagged-tag: 1.0.0 + + typescript@5.9.3: {} + + undici-types@7.18.2: {} + + unicorn-magic@0.4.0: {} + + upper-case-first@2.0.2: + dependencies: + tslib: 2.8.1 + + util-arity@1.1.0: {} + + validate-npm-package-license@3.0.4: + dependencies: + spdx-correct: 3.2.0 + spdx-expression-parse: 3.0.1 + + vite-plus@0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3): + dependencies: + '@oxc-project/types': 0.122.0 + '@voidzero-dev/vite-plus-core': 0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + '@voidzero-dev/vite-plus-test': 0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + cac: 7.0.0 + cross-spawn: 7.0.6 + oxfmt: 0.42.0 + oxlint: 1.57.0(oxlint-tsgolint@0.17.3) + oxlint-tsgolint: 0.17.3 + picocolors: 1.1.1 + optionalDependencies: + '@voidzero-dev/vite-plus-darwin-arm64': 0.1.14 + '@voidzero-dev/vite-plus-darwin-x64': 0.1.14 + '@voidzero-dev/vite-plus-linux-arm64-gnu': 0.1.14 + '@voidzero-dev/vite-plus-linux-arm64-musl': 0.1.14 + '@voidzero-dev/vite-plus-linux-x64-gnu': 0.1.14 + '@voidzero-dev/vite-plus-linux-x64-musl': 0.1.14 + '@voidzero-dev/vite-plus-win32-arm64-msvc': 0.1.14 + '@voidzero-dev/vite-plus-win32-x64-msvc': 0.1.14 + transitivePeerDependencies: + - '@arethetypeswrong/core' + - '@edge-runtime/vm' + - '@opentelemetry/api' + - '@tsdown/css' + - '@tsdown/exe' + - '@types/node' + - '@vitejs/devtools' + - '@vitest/ui' + - bufferutil + - esbuild + - happy-dom + - jiti + - jsdom + - less + - publint + - sass + - sass-embedded + - stylus + - sugarss + - terser + - tsx + - typescript + - unplugin-unused + - utf-8-validate + - vite + - yaml + + vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3): + dependencies: + lightningcss: 1.32.0 + picomatch: 4.0.4 + postcss: 8.5.8 + rolldown: 1.0.0-rc.12 + tinyglobby: 0.2.15 + optionalDependencies: + '@types/node': 25.5.0 + esbuild: 0.27.4 + fsevents: 2.3.3 + tsx: 4.21.0 + yaml: 2.8.3 + + which@2.0.2: + dependencies: + isexe: 2.0.0 + + ws@8.20.0: {} + + xmlbuilder@15.1.1: {} + + yaml@2.8.3: {} + + yup@1.7.1: + dependencies: + property-expr: 2.0.6 + tiny-case: 1.0.3 + toposort: 2.0.2 + type-fest: 2.19.0 diff --git a/e2e/scripts/common.ts b/e2e/scripts/common.ts new file mode 100644 index 0000000000..bb82121079 --- /dev/null +++ b/e2e/scripts/common.ts @@ -0,0 +1,242 @@ +import { spawn, type ChildProcess } from 'node:child_process' +import { access, copyFile, readFile, writeFile } from 'node:fs/promises' +import net from 'node:net' +import path from 'node:path' +import { fileURLToPath, pathToFileURL } from 'node:url' +import { sleep } from '../support/process' + +type RunCommandOptions = { + command: string + args: string[] + cwd: string + env?: NodeJS.ProcessEnv + stdio?: 'inherit' | 'pipe' +} + +type RunCommandResult = { + exitCode: number + stdout: string + stderr: string +} + +type ForegroundProcessOptions = { + command: string + args: string[] + cwd: string + env?: NodeJS.ProcessEnv +} + +export const rootDir = fileURLToPath(new URL('../..', import.meta.url)) +export const e2eDir = path.join(rootDir, 'e2e') +export const apiDir = path.join(rootDir, 'api') +export const dockerDir = path.join(rootDir, 'docker') +export const webDir = path.join(rootDir, 'web') + +export const middlewareComposeFile = path.join(dockerDir, 
'docker-compose.middleware.yaml')
+export const middlewareEnvFile = path.join(dockerDir, 'middleware.env')
+export const middlewareEnvExampleFile = path.join(dockerDir, 'middleware.env.example')
+export const webEnvLocalFile = path.join(webDir, '.env.local')
+export const webEnvExampleFile = path.join(webDir, '.env.example')
+export const apiEnvExampleFile = path.join(apiDir, 'tests', 'integration_tests', '.env.example')
+
+const formatCommand = (command: string, args: string[]) => [command, ...args].join(' ')
+
+export const isMainModule = (metaUrl: string) => {
+  const entrypoint = process.argv[1]
+  if (!entrypoint) return false
+
+  return pathToFileURL(entrypoint).href === metaUrl
+}
+
+export const runCommand = async ({
+  command,
+  args,
+  cwd,
+  env,
+  stdio = 'inherit',
+}: RunCommandOptions): Promise<RunCommandResult> => {
+  const childProcess = spawn(command, args, {
+    cwd,
+    env: {
+      ...process.env,
+      ...env,
+    },
+    stdio: stdio === 'inherit' ? 'inherit' : 'pipe',
+  })
+
+  let stdout = ''
+  let stderr = ''
+
+  if (stdio === 'pipe') {
+    childProcess.stdout?.on('data', (chunk: Buffer | string) => {
+      stdout += chunk.toString()
+    })
+    childProcess.stderr?.on('data', (chunk: Buffer | string) => {
+      stderr += chunk.toString()
+    })
+  }
+
+  return await new Promise<RunCommandResult>((resolve, reject) => {
+    childProcess.once('error', reject)
+    childProcess.once('exit', (code) => {
+      resolve({
+        exitCode: code ?? 1,
+        stdout,
+        stderr,
+      })
+    })
+  })
+}
+
+export const runCommandOrThrow = async (options: RunCommandOptions) => {
+  const result = await runCommand(options)
+
+  if (result.exitCode !== 0) {
+    throw new Error(
+      `Command failed (${result.exitCode}): ${formatCommand(options.command, options.args)}`,
+    )
+  }
+
+  return result
+}
+
+const forwardSignalsToChild = (childProcess: ChildProcess) => {
+  const handleSignal = (signal: NodeJS.Signals) => {
+    if (childProcess.exitCode === null) childProcess.kill(signal)
+  }
+
+  const onSigint = () => handleSignal('SIGINT')
+  const onSigterm = () => handleSignal('SIGTERM')
+
+  process.on('SIGINT', onSigint)
+  process.on('SIGTERM', onSigterm)
+
+  return () => {
+    process.off('SIGINT', onSigint)
+    process.off('SIGTERM', onSigterm)
+  }
+}
+
+export const runForegroundProcess = async ({
+  command,
+  args,
+  cwd,
+  env,
+}: ForegroundProcessOptions) => {
+  const childProcess = spawn(command, args, {
+    cwd,
+    env: {
+      ...process.env,
+      ...env,
+    },
+    stdio: 'inherit',
+  })
+
+  const cleanupSignals = forwardSignalsToChild(childProcess)
+  const exitCode = await new Promise<number>((resolve, reject) => {
+    childProcess.once('error', reject)
+    childProcess.once('exit', (code) => {
+      resolve(code ?? 1)
+    })
+  })
+
+  cleanupSignals()
+  process.exit(exitCode)
+}
+
+export const ensureFileExists = async (filePath: string, exampleFilePath: string) => {
+  try {
+    await access(filePath)
+  } catch {
+    await copyFile(exampleFilePath, filePath)
+  }
+}
+
+export const ensureLineInFile = async (filePath: string, line: string) => {
+  const fileContent = await readFile(filePath, 'utf8')
+  const lines = fileContent.split(/\r?\n/)
+  const assignmentPrefix = line.includes('=') ? `${line.slice(0, line.indexOf('='))}=` : null
+
+  if (lines.includes(line)) return
+
+  if (assignmentPrefix && lines.some((existingLine) => existingLine.startsWith(assignmentPrefix)))
+    return
+
+  const normalizedContent = fileContent.endsWith('\n') ? fileContent : `${fileContent}\n`
+  await writeFile(filePath, `${normalizedContent}${line}\n`, 'utf8')
+}
+
+export const ensureWebEnvLocal = async () => {
+  await ensureFileExists(webEnvLocalFile, webEnvExampleFile)
+
+  const fileContent = await readFile(webEnvLocalFile, 'utf8')
+  const nextContent = fileContent.replaceAll('http://localhost:5001', 'http://127.0.0.1:5001')
+
+  if (nextContent !== fileContent) await writeFile(webEnvLocalFile, nextContent, 'utf8')
+}
+
+export const readSimpleDotenv = async (filePath: string) => {
+  const fileContent = await readFile(filePath, 'utf8')
+  const entries = fileContent
+    .split(/\r?\n/)
+    .map((line) => line.trim())
+    .filter((line) => line && !line.startsWith('#'))
+    .map<[string, string]>((line) => {
+      const separatorIndex = line.indexOf('=')
+      const key = separatorIndex === -1 ? line : line.slice(0, separatorIndex).trim()
+      const rawValue = separatorIndex === -1 ? '' : line.slice(separatorIndex + 1).trim()
+
+      if (
+        (rawValue.startsWith('"') && rawValue.endsWith('"')) ||
+        (rawValue.startsWith("'") && rawValue.endsWith("'"))
+      ) {
+        return [key, rawValue.slice(1, -1)]
+      }
+
+      return [key, rawValue]
+    })
+
+  return Object.fromEntries(entries)
+}
+
+export const waitForCondition = async ({
+  check,
+  description,
+  intervalMs,
+  timeoutMs,
+}: {
+  check: () => Promise<boolean> | boolean
+  description: string
+  intervalMs: number
+  timeoutMs: number
+}) => {
+  const deadline = Date.now() + timeoutMs
+
+  while (Date.now() < deadline) {
+    if (await check()) return
+
+    await sleep(intervalMs)
+  }
+
+  throw new Error(`Timed out waiting for ${description} after ${timeoutMs}ms.`)
+}
+
+export const isTcpPortReachable = async (host: string, port: number, timeoutMs = 1_000) => {
+  return await new Promise<boolean>((resolve) => {
+    const socket = net.createConnection({
+      host,
+      port,
+    })
+
+    const finish = (result: boolean) => {
+      socket.removeAllListeners()
+      socket.destroy()
+      resolve(result)
+    }
+
+    socket.setTimeout(timeoutMs)
+    socket.once('connect', () => finish(true))
+    socket.once('timeout', () => finish(false))
+    socket.once('error', () => finish(false))
+  })
+}
diff --git a/e2e/scripts/run-cucumber.ts b/e2e/scripts/run-cucumber.ts
new file mode 100644
index 0000000000..39e9157916
--- /dev/null
+++ b/e2e/scripts/run-cucumber.ts
@@ -0,0 +1,154 @@
+import { mkdir, rm } from 'node:fs/promises'
+import path from 'node:path'
+import { startWebServer, stopWebServer } from '../support/web-server'
+import { waitForUrl, startLoggedProcess, stopManagedProcess } from '../support/process'
+import { apiURL, baseURL, reuseExistingWebServer } from '../test-env'
+import { e2eDir, isMainModule, runCommand } from './common'
+import { resetState, startMiddleware, stopMiddleware } from './setup'
+
+type RunOptions = {
+  forwardArgs: string[]
+  full: boolean
+  headed: boolean
+}
+
+const parseArgs = (argv: string[]): RunOptions => {
+  let full = false
+  let headed = false
+  const forwardArgs: string[] = []
+
+  for (let index = 0; index < argv.length; index += 1) {
+    const arg = argv[index]
+
+    if (arg === '--') {
+      forwardArgs.push(...argv.slice(index + 1))
+      break
+    }
+
+    if (arg === '--full') {
+      full = true
+      continue
+    }
+
+    if (arg === '--headed') {
+      headed = true
+      continue
+    }
+
+    forwardArgs.push(arg)
+  }
+
+  return {
+    forwardArgs,
+    full,
+    headed,
+  }
+}
+
+const hasCustomTags = (forwardArgs: string[]) =>
+  forwardArgs.some((arg) => arg === '--tags' || arg.startsWith('--tags='))
+
+const main = async () => {
+  const { forwardArgs, full, headed } = parseArgs(process.argv.slice(2))
+  const startMiddlewareForRun = full
+  const resetStateForRun = full
+
+  if (resetStateForRun) await resetState()
+
+  if (startMiddlewareForRun) await startMiddleware()
+
+  const cucumberReportDir = path.join(e2eDir, 'cucumber-report')
+  const logDir = path.join(e2eDir, '.logs')
+
+  await rm(cucumberReportDir, { force: true, recursive: true })
+  await mkdir(logDir, { recursive: true })
+
+  const apiProcess = await startLoggedProcess({
+    command: 'npx',
+    args: ['tsx', './scripts/setup.ts', 'api'],
+    cwd: e2eDir,
+    label: 'api server',
+    logFilePath: path.join(logDir, 'cucumber-api.log'),
+  })
+
+  let cleanupPromise: Promise<void> | undefined
+  const cleanup = async () => {
+    if (!cleanupPromise) {
+      cleanupPromise = (async () => {
+        await stopWebServer()
+        await stopManagedProcess(apiProcess)
+
+        if (startMiddlewareForRun) {
+          try {
+            await stopMiddleware()
+          } catch {
+            // Cleanup should continue even if middleware shutdown fails.
+          }
+        }
+      })()
+    }
+
+    await cleanupPromise
+  }
+
+  const onTerminate = () => {
+    void cleanup().finally(() => {
+      process.exit(1)
+    })
+  }
+
+  process.once('SIGINT', onTerminate)
+  process.once('SIGTERM', onTerminate)
+
+  try {
+    try {
+      await waitForUrl(`${apiURL}/health`, 180_000, 1_000)
+    } catch {
+      throw new Error(`API did not become ready at ${apiURL}/health.`)
+    }
+
+    await startWebServer({
+      baseURL,
+      command: 'npx',
+      args: ['tsx', './scripts/setup.ts', 'web'],
+      cwd: e2eDir,
+      logFilePath: path.join(logDir, 'cucumber-web.log'),
+      reuseExistingServer: reuseExistingWebServer,
+      timeoutMs: 300_000,
+    })
+
+    const cucumberEnv: NodeJS.ProcessEnv = {
+      ...process.env,
+      CUCUMBER_HEADLESS: headed ? '0' : '1',
+    }
+
+    if (startMiddlewareForRun && !hasCustomTags(forwardArgs))
+      cucumberEnv.E2E_CUCUMBER_TAGS = 'not @skip'
+
+    const result = await runCommand({
+      command: 'npx',
+      args: [
+        'tsx',
+        './node_modules/@cucumber/cucumber/bin/cucumber.js',
+        '--config',
+        './cucumber.config.ts',
+        ...forwardArgs,
+      ],
+      cwd: e2eDir,
+      env: cucumberEnv,
+    })
+
+    process.exitCode = result.exitCode
+  } finally {
+    process.off('SIGINT', onTerminate)
+    process.off('SIGTERM', onTerminate)
+    await cleanup()
+  }
+}
+
+if (isMainModule(import.meta.url)) {
+  void main().catch((error) => {
+    console.error(error instanceof Error ? error.message : String(error))
+    process.exit(1)
+  })
+}
diff --git a/e2e/scripts/setup.ts b/e2e/scripts/setup.ts
new file mode 100644
index 0000000000..6f38598df4
--- /dev/null
+++ b/e2e/scripts/setup.ts
@@ -0,0 +1,306 @@
+import { access, mkdir, rm } from 'node:fs/promises'
+import path from 'node:path'
+import { waitForUrl } from '../support/process'
+import {
+  apiDir,
+  apiEnvExampleFile,
+  dockerDir,
+  e2eDir,
+  ensureFileExists,
+  ensureLineInFile,
+  ensureWebEnvLocal,
+  isMainModule,
+  isTcpPortReachable,
+  middlewareComposeFile,
+  middlewareEnvExampleFile,
+  middlewareEnvFile,
+  readSimpleDotenv,
+  runCommand,
+  runCommandOrThrow,
+  runForegroundProcess,
+  waitForCondition,
+  webDir,
+} from './common'
+
+const buildIdPath = path.join(webDir, '.next', 'BUILD_ID')
+
+const middlewareDataPaths = [
+  path.join(dockerDir, 'volumes', 'db', 'data'),
+  path.join(dockerDir, 'volumes', 'plugin_daemon'),
+  path.join(dockerDir, 'volumes', 'redis', 'data'),
+  path.join(dockerDir, 'volumes', 'weaviate'),
+]
+
+const e2eStatePaths = [
+  path.join(e2eDir, '.auth'),
+  path.join(e2eDir, 'cucumber-report'),
+  path.join(e2eDir, '.logs'),
+  path.join(e2eDir, 'playwright-report'),
+  path.join(e2eDir, 'test-results'),
+]
+
+const composeArgs = [
+  'compose',
+  '-f',
+  middlewareComposeFile,
+  '--profile',
+  'postgresql',
+  '--profile',
+  'weaviate',
+]
+
+const getApiEnvironment = async () => {
+  const envFromExample = await readSimpleDotenv(apiEnvExampleFile)
+
+  return {
+    ...envFromExample,
+    FLASK_APP: 'app.py',
+  }
+}
+
+const getServiceContainerId = async (service: string) => {
+  const result = await runCommandOrThrow({
+    command: 'docker',
+    args: ['compose', '-f', middlewareComposeFile, 'ps', '-q', service],
+    cwd: dockerDir,
+    stdio: 'pipe',
+  })
+
+  return result.stdout.trim()
+}
+
+const getContainerHealth = async (containerId: string) => {
+  const result = await runCommand({
+    command: 'docker',
+    args: ['inspect', '-f', '{{.State.Health.Status}}', containerId],
+    cwd: dockerDir,
+    stdio: 'pipe',
+  })
+
+  if (result.exitCode !== 0) return ''
+
+  return result.stdout.trim()
+}
+
+const printComposeLogs = async (services: string[]) => {
+  await runCommand({
+    command: 'docker',
+    args: ['compose', '-f', middlewareComposeFile, 'logs', ...services],
+    cwd: dockerDir,
+  })
+}
+
+const waitForDependency = async ({
+  description,
+  services,
+  wait,
+}: {
+  description: string
+  services: string[]
+  wait: () => Promise<void>
+}) => {
+  console.log(`Waiting for ${description}...`)
+
+  try {
+    await wait()
+  } catch (error) {
+    await printComposeLogs(services)
+    throw error
+  }
+}
+
+export const ensureWebBuild = async () => {
+  await ensureWebEnvLocal()
+
+  if (process.env.E2E_FORCE_WEB_BUILD === '1') {
+    await runCommandOrThrow({
+      command: 'pnpm',
+      args: ['run', 'build'],
+      cwd: webDir,
+    })
+    return
+  }
+
+  try {
+    await access(buildIdPath)
+    console.log('Reusing existing web build artifact.')
+  } catch {
+    await runCommandOrThrow({
+      command: 'pnpm',
+      args: ['run', 'build'],
+      cwd: webDir,
+    })
+  }
+}
+
+export const startWeb = async () => {
+  await ensureWebBuild()
+
+  await runForegroundProcess({
+    command: 'pnpm',
+    args: ['run', 'start'],
+    cwd: webDir,
+    env: {
+      HOSTNAME: '127.0.0.1',
+      PORT: '3000',
+    },
+  })
+}
+
+export const startApi = async () => {
+  const env = await getApiEnvironment()
+
+  await runCommandOrThrow({
+    command: 'uv',
+    args: ['run', '--project', '.', 'flask', 'upgrade-db'],
+    cwd: apiDir,
+    env,
+  })
+
+  await runForegroundProcess({
+    command: 'uv',
+    args: ['run', '--project', '.', 'flask', 'run', '--host', '127.0.0.1', '--port', '5001'],
+    cwd: apiDir,
+    env,
+  })
+}
+
+export const stopMiddleware = async () => {
+  await runCommandOrThrow({
+    command: 'docker',
+    args: [...composeArgs, 'down', '--remove-orphans'],
+    cwd: dockerDir,
+  })
+}
+
+export const resetState = async () => {
+  console.log('Stopping middleware services...')
+  try {
+    await stopMiddleware()
+  } catch {
+    // Reset should continue even if middleware is already stopped.
+  }
+
+  console.log('Removing persisted middleware data...')
+  await Promise.all(
+    middlewareDataPaths.map(async (targetPath) => {
+      await rm(targetPath, { force: true, recursive: true })
+      await mkdir(targetPath, { recursive: true })
+    }),
+  )
+
+  console.log('Removing E2E local state...')
+  await Promise.all(
+    e2eStatePaths.map((targetPath) => rm(targetPath, { force: true, recursive: true })),
+  )
+
+  console.log('E2E state reset complete.')
+}
+
+export const startMiddleware = async () => {
+  await ensureFileExists(middlewareEnvFile, middlewareEnvExampleFile)
+  await ensureLineInFile(middlewareEnvFile, 'COMPOSE_PROFILES=postgresql,weaviate')
+
+  console.log('Starting middleware services...')
+  await runCommandOrThrow({
+    command: 'docker',
+    args: [
+      ...composeArgs,
+      'up',
+      '-d',
+      'db_postgres',
+      'redis',
+      'weaviate',
+      'sandbox',
+      'ssrf_proxy',
+      'plugin_daemon',
+    ],
+    cwd: dockerDir,
+  })
+
+  const [postgresContainerId, redisContainerId] = await Promise.all([
+    getServiceContainerId('db_postgres'),
+    getServiceContainerId('redis'),
+  ])
+
+  await waitForDependency({
+    description: 'PostgreSQL and Redis health checks',
+    services: ['db_postgres', 'redis'],
+    wait: () =>
+      waitForCondition({
+        check: async () => {
+          const [postgresStatus, redisStatus] = await Promise.all([
+            getContainerHealth(postgresContainerId),
+            getContainerHealth(redisContainerId),
+          ])
+
+          return postgresStatus === 'healthy' && redisStatus === 'healthy'
+        },
+        description: 'PostgreSQL and Redis health checks',
+        intervalMs: 2_000,
+        timeoutMs: 240_000,
+      }),
+  })
+
+  await waitForDependency({
+    description: 'Weaviate readiness',
+    services: ['weaviate'],
+    wait: () => waitForUrl('http://127.0.0.1:8080/v1/.well-known/ready', 120_000, 2_000),
+  })
+
+  await waitForDependency({
+    description: 'sandbox health',
+    services: ['sandbox', 'ssrf_proxy'],
+    wait: () => waitForUrl('http://127.0.0.1:8194/health', 120_000, 2_000),
+  })
+
+  await waitForDependency({
+    description: 'plugin daemon port',
+    services: ['plugin_daemon'],
+    wait: () =>
+      waitForCondition({
+        check: async () => isTcpPortReachable('127.0.0.1', 5002),
+        description: 'plugin daemon port',
+        intervalMs: 2_000,
+        timeoutMs: 120_000,
+      }),
+  })
+
+  console.log('Full middleware stack is ready.')
+}
+
+const printUsage = () => {
+  console.log('Usage: tsx ./scripts/setup.ts <api|middleware-down|middleware-up|reset|web>')
+}
+
+const main = async () => {
+  const command = process.argv[2]
+
+  switch (command) {
+    case 'api':
+      await startApi()
+      return
+    case 'middleware-down':
+      await stopMiddleware()
+      return
+    case 'middleware-up':
+      await startMiddleware()
+      return
+    case 'reset':
+      await resetState()
+      return
+    case 'web':
+      await startWeb()
+      return
+    default:
+      printUsage()
+      process.exitCode = 1
+  }
+}
+
+if (isMainModule(import.meta.url)) {
+  void main().catch((error) => {
+    console.error(error instanceof Error ? error.message : String(error))
+    process.exit(1)
+  })
+}
diff --git a/e2e/support/process.ts b/e2e/support/process.ts
new file mode 100644
index 0000000000..96273ef931
--- /dev/null
+++ b/e2e/support/process.ts
@@ -0,0 +1,178 @@
+import type { ChildProcess } from 'node:child_process'
+import { spawn } from 'node:child_process'
+import { createWriteStream, type WriteStream } from 'node:fs'
+import { mkdir } from 'node:fs/promises'
+import net from 'node:net'
+import { dirname } from 'node:path'
+
+type ManagedProcessOptions = {
+  command: string
+  args?: string[]
+  cwd: string
+  env?: NodeJS.ProcessEnv
+  label: string
+  logFilePath: string
+}
+
+export type ManagedProcess = {
+  childProcess: ChildProcess
+  label: string
+  logFilePath: string
+  logStream: WriteStream
+}
+
+export const sleep = (ms: number) =>
+  new Promise<void>((resolve) => {
+    setTimeout(resolve, ms)
+  })
+
+export const isPortReachable = async (host: string, port: number, timeoutMs = 1_000) => {
+  return await new Promise<boolean>((resolve) => {
+    const socket = net.createConnection({
+      host,
+      port,
+    })
+
+    const finish = (result: boolean) => {
+      socket.removeAllListeners()
+      socket.destroy()
+      resolve(result)
+    }
+
+    socket.setTimeout(timeoutMs)
+    socket.once('connect', () => finish(true))
+    socket.once('timeout', () => finish(false))
+    socket.once('error', () => finish(false))
+  })
+}
+
+export const waitForUrl = async (
+  url: string,
+  timeoutMs: number,
+  intervalMs = 1_000,
+  requestTimeoutMs = Math.max(intervalMs, 1_000),
+) => {
+  const deadline = Date.now() + timeoutMs
+
+  while (Date.now() < deadline) {
+    try {
+      const controller = new AbortController()
+      const timeout = setTimeout(() => controller.abort(), requestTimeoutMs)
+
+      try {
+        const response = await fetch(url, {
+          signal: controller.signal,
+        })
+        if (response.ok) return
+      } finally {
+        clearTimeout(timeout)
+      }
+    } catch {
+      // Keep polling until timeout.
+    }
+
+    await sleep(intervalMs)
+  }
+
+  throw new Error(`Timed out waiting for ${url} after ${timeoutMs}ms.`)
+}
+
+export const startLoggedProcess = async ({
+  command,
+  args = [],
+  cwd,
+  env,
+  label,
+  logFilePath,
+}: ManagedProcessOptions): Promise<ManagedProcess> => {
+  await mkdir(dirname(logFilePath), { recursive: true })
+
+  const logStream = createWriteStream(logFilePath, { flags: 'a' })
+  const childProcess = spawn(command, args, {
+    cwd,
+    env: {
+      ...process.env,
+      ...env,
+    },
+    detached: process.platform !== 'win32',
+    stdio: ['ignore', 'pipe', 'pipe'],
+  })
+
+  const formattedCommand = [command, ...args].join(' ')
+  logStream.write(`[${new Date().toISOString()}] Starting ${label}: ${formattedCommand}\n`)
+  childProcess.stdout?.pipe(logStream, { end: false })
+  childProcess.stderr?.pipe(logStream, { end: false })
+
+  return {
+    childProcess,
+    label,
+    logFilePath,
+    logStream,
+  }
+}
+
+const waitForProcessExit = (childProcess: ChildProcess, timeoutMs: number) =>
+  new Promise<void>((resolve) => {
+    if (childProcess.exitCode !== null) {
+      resolve()
+      return
+    }
+
+    const timeout = setTimeout(() => {
+      cleanup()
+      resolve()
+    }, timeoutMs)
+
+    const onExit = () => {
+      cleanup()
+      resolve()
+    }
+
+    const cleanup = () => {
+      clearTimeout(timeout)
+      childProcess.off('exit', onExit)
+    }
+
+    childProcess.once('exit', onExit)
+  })
+
+const signalManagedProcess = (childProcess: ChildProcess, signal: NodeJS.Signals) => {
+  const { pid } = childProcess
+  if (!pid) return
+
+  try {
+    if (process.platform !== 'win32') {
+      process.kill(-pid, signal)
+      return
+    }
+
+    childProcess.kill(signal)
+  } catch {
+    // Best-effort shutdown. Cleanup continues even when the process is already gone.
+  }
+}
+
+export const stopManagedProcess = async (managedProcess?: ManagedProcess) => {
+  if (!managedProcess) return
+
+  const { childProcess, logStream } = managedProcess
+
+  if (childProcess.exitCode === null) {
+    signalManagedProcess(childProcess, 'SIGTERM')
+    await waitForProcessExit(childProcess, 5_000)
+  }
+
+  if (childProcess.exitCode === null) {
+    signalManagedProcess(childProcess, 'SIGKILL')
+    await waitForProcessExit(childProcess, 5_000)
+  }
+
+  childProcess.stdout?.unpipe(logStream)
+  childProcess.stderr?.unpipe(logStream)
+  childProcess.stdout?.destroy()
+  childProcess.stderr?.destroy()
+
+  await new Promise<void>((resolve) => {
+    logStream.end(() => resolve())
+  })
+}
diff --git a/e2e/support/web-server.ts b/e2e/support/web-server.ts
new file mode 100644
index 0000000000..ad5d5d916a
--- /dev/null
+++ b/e2e/support/web-server.ts
@@ -0,0 +1,83 @@
+import type { ManagedProcess } from './process'
+import { isPortReachable, startLoggedProcess, stopManagedProcess, waitForUrl } from './process'
+
+type WebServerStartOptions = {
+  baseURL: string
+  command: string
+  args?: string[]
+  cwd: string
+  logFilePath: string
+  reuseExistingServer: boolean
+  timeoutMs: number
+}
+
+let activeProcess: ManagedProcess | undefined
+
+const getUrlHostAndPort = (url: string) => {
+  const parsedUrl = new URL(url)
+  const isHttps = parsedUrl.protocol === 'https:'
+
+  return {
+    host: parsedUrl.hostname,
+    port: parsedUrl.port ? Number(parsedUrl.port) : isHttps ?
443 : 80, + } +} + +export const startWebServer = async ({ + baseURL, + command, + args = [], + cwd, + logFilePath, + reuseExistingServer, + timeoutMs, +}: WebServerStartOptions) => { + const { host, port } = getUrlHostAndPort(baseURL) + + if (reuseExistingServer && (await isPortReachable(host, port))) return + + activeProcess = await startLoggedProcess({ + command, + args, + cwd, + label: 'web server', + logFilePath, + }) + + let startupError: Error | undefined + activeProcess.childProcess.once('error', (error) => { + startupError = error + }) + activeProcess.childProcess.once('exit', (code, signal) => { + if (startupError) return + + startupError = new Error( + `Web server exited before readiness (code: ${code ?? 'unknown'}, signal: ${signal ?? 'none'}).`, + ) + }) + + const deadline = Date.now() + timeoutMs + while (Date.now() < deadline) { + if (startupError) { + await stopManagedProcess(activeProcess) + activeProcess = undefined + throw startupError + } + + try { + await waitForUrl(baseURL, 1_000, 250, 1_000) + return + } catch { + // Continue polling until timeout or child exit. + } + } + + await stopManagedProcess(activeProcess) + activeProcess = undefined + throw new Error(`Timed out waiting for web server readiness at ${baseURL} after ${timeoutMs}ms.`) +} + +export const stopWebServer = async () => { + await stopManagedProcess(activeProcess) + activeProcess = undefined +} diff --git a/e2e/test-env.ts b/e2e/test-env.ts new file mode 100644 index 0000000000..c0afc2a8c1 --- /dev/null +++ b/e2e/test-env.ts @@ -0,0 +1,12 @@ +export const defaultBaseURL = 'http://127.0.0.1:3000' +export const defaultApiURL = 'http://127.0.0.1:5001' +export const defaultLocale = 'en-US' + +export const baseURL = process.env.E2E_BASE_URL || defaultBaseURL +export const apiURL = process.env.E2E_API_URL || defaultApiURL + +export const cucumberHeadless = process.env.CUCUMBER_HEADLESS !== '0' +export const cucumberSlowMo = Number(process.env.E2E_SLOW_MO || 0) +export const reuseExistingWebServer = process.env.E2E_REUSE_WEB_SERVER + ? 
process.env.E2E_REUSE_WEB_SERVER !== '0' + : !process.env.CI diff --git a/e2e/tsconfig.json b/e2e/tsconfig.json new file mode 100644 index 0000000000..3976c12b66 --- /dev/null +++ b/e2e/tsconfig.json @@ -0,0 +1,25 @@ +{ + "compilerOptions": { + "target": "ES2023", + "lib": ["ES2023", "DOM"], + "module": "ESNext", + "moduleResolution": "Bundler", + "allowJs": false, + "resolveJsonModule": true, + "noEmit": true, + "strict": true, + "skipLibCheck": true, + "types": ["node", "@playwright/test", "@cucumber/cucumber"], + "isolatedModules": true, + "verbatimModuleSyntax": true + }, + "include": ["./**/*.ts"], + "exclude": [ + "./node_modules", + "./playwright-report", + "./test-results", + "./.auth", + "./cucumber-report", + "./.logs" + ] +} diff --git a/e2e/vite.config.ts b/e2e/vite.config.ts new file mode 100644 index 0000000000..98400d5b9b --- /dev/null +++ b/e2e/vite.config.ts @@ -0,0 +1,15 @@ +import { defineConfig } from 'vite-plus' + +export default defineConfig({ + lint: { + options: { + typeAware: true, + typeCheck: true, + denyWarnings: true, + }, + }, + fmt: { + singleQuote: true, + semi: false, + }, +}) diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml index 722fe5b1bc..30d3cf61ee 100644 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ b/sdks/nodejs-client/pnpm-lock.yaml @@ -326,79 +326,66 @@ packages: resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==} cpu: [arm] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.59.0': resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==} cpu: [arm] os: [linux] - libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.59.0': resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==} cpu: [arm64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.59.0': resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==} cpu: [arm64] os: [linux] - libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.59.0': resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==} cpu: [loong64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-loong64-musl@4.59.0': resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==} cpu: [loong64] os: [linux] - libc: [musl] '@rollup/rollup-linux-ppc64-gnu@4.59.0': resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==} cpu: [ppc64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-ppc64-musl@4.59.0': resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==} cpu: [ppc64] os: [linux] - libc: [musl] '@rollup/rollup-linux-riscv64-gnu@4.59.0': resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==} cpu: [riscv64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.59.0': resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==} cpu: [riscv64] os: [linux] - libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.59.0': resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==} cpu: [s390x] os: [linux] - 
libc: [glibc]

   '@rollup/rollup-linux-x64-gnu@4.59.0':
     resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==}
     cpu: [x64]
     os: [linux]
-    libc: [glibc]

   '@rollup/rollup-linux-x64-musl@4.59.0':
     resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==}
     cpu: [x64]
     os: [linux]
-    libc: [musl]

   '@rollup/rollup-openbsd-x64@4.59.0':
     resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==}
diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts
deleted file mode 100644
index 9e9b3d7168..0000000000
--- a/web/__tests__/check-i18n.test.ts
+++ /dev/null
@@ -1,860 +0,0 @@
-import fs from 'node:fs'
-import path from 'node:path'
-import vm from 'node:vm'
-import { transpile } from 'typescript'
-
-describe('i18n:check script functionality', () => {
-  const testDir = path.join(__dirname, '../i18n-test')
-  const testEnDir = path.join(testDir, 'en-US')
-  const testZhDir = path.join(testDir, 'zh-Hans')
-
-  // Helper function that replicates the getKeysFromLanguage logic
-  async function getKeysFromLanguage(language: string, testPath = testDir): Promise<string[]> {
-    return new Promise((resolve, reject) => {
-      const folderPath = path.resolve(testPath, language)
-      const allKeys: string[] = []
-
-      if (!fs.existsSync(folderPath)) {
-        resolve([])
-        return
-      }
-
-      fs.readdir(folderPath, (err, files) => {
-        if (err) {
-          reject(err)
-          return
-        }
-
-        const translationFiles = files.filter(file => /\.(ts|js)$/.test(file))
-
-        translationFiles.forEach((file) => {
-          const filePath = path.join(folderPath, file)
-          const fileName = file.replace(/\.[^/.]+$/, '')
-          const camelCaseFileName = fileName.replace(/[-_](.)/g, (_, c) =>
-            c.toUpperCase())
-
-          try {
-            const content = fs.readFileSync(filePath, 'utf8')
-            const moduleExports = {}
-            const context = {
-              exports: moduleExports,
-              module: { exports: moduleExports },
-              require,
-              console,
-              __filename: filePath,
-              __dirname: folderPath,
-            }
-
-            vm.runInNewContext(transpile(content), context)
-            const translationObj = (context.module.exports as any).default || context.module.exports
-
-            if (!translationObj || typeof translationObj !== 'object')
-              throw new Error(`Error parsing file: ${filePath}`)
-
-            const nestedKeys: string[] = []
-            const iterateKeys = (obj: any, prefix = '') => {
-              for (const key in obj) {
-                const nestedKey = prefix ?
`${prefix}.${key}` : key - if (typeof obj[key] === 'object' && obj[key] !== null && !Array.isArray(obj[key])) { - // This is an object (but not array), recurse into it but don't add it as a key - iterateKeys(obj[key], nestedKey) - } - else { - // This is a leaf node (string, number, boolean, array, etc.), add it as a key - nestedKeys.push(nestedKey) - } - } - } - iterateKeys(translationObj) - - const fileKeys = nestedKeys.map(key => `${camelCaseFileName}.${key}`) - allKeys.push(...fileKeys) - } - catch (error) { - reject(error) - } - }) - resolve(allKeys) - }) - }) - } - - beforeEach(() => { - // Clean up and create test directories - if (fs.existsSync(testDir)) - fs.rmSync(testDir, { recursive: true }) - - fs.mkdirSync(testDir, { recursive: true }) - fs.mkdirSync(testEnDir, { recursive: true }) - fs.mkdirSync(testZhDir, { recursive: true }) - }) - - afterEach(() => { - // Clean up test files - if (fs.existsSync(testDir)) - fs.rmSync(testDir, { recursive: true }) - }) - - describe('Key extraction logic', () => { - it('should extract only leaf node keys, not intermediate objects', async () => { - const testContent = `const translation = { - simple: 'Simple Value', - nested: { - level1: 'Level 1 Value', - deep: { - level2: 'Level 2 Value' - } - }, - array: ['not extracted'], - number: 42, - boolean: true -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'test.ts'), testContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toEqual([ - 'test.simple', - 'test.nested.level1', - 'test.nested.deep.level2', - 'test.array', - 'test.number', - 'test.boolean', - ]) - - // Should not include intermediate object keys - expect(keys).not.toContain('test.nested') - expect(keys).not.toContain('test.nested.deep') - }) - - it('should handle camelCase file name conversion correctly', async () => { - const testContent = `const translation = { - key: 'value' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), testContent) - fs.writeFileSync(path.join(testEnDir, 'user_profile.ts'), testContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('appDebug.key') - expect(keys).toContain('userProfile.key') - }) - }) - - describe('Missing keys detection', () => { - it('should detect missing keys in target language', async () => { - const enContent = `const translation = { - common: { - save: 'Save', - cancel: 'Cancel', - delete: 'Delete' - }, - app: { - title: 'My App', - version: '1.0' - } -} - -export default translation -` - - const zhContent = `const translation = { - common: { - save: '保存', - cancel: '取消' - // missing 'delete' - }, - app: { - title: '我的应用' - // missing 'version' - } -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) - fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) - - expect(missingKeys).toContain('test.common.delete') - expect(missingKeys).toContain('test.app.version') - expect(missingKeys).toHaveLength(2) - }) - }) - - describe('Extra keys detection', () => { - it('should detect extra keys in target language', async () => { - const enContent = `const translation = { - common: { - save: 'Save', - cancel: 'Cancel' - } -} - -export default translation -` - - const zhContent = `const translation = { - common: { - save: '保存', - cancel: '取消', 
-    delete: '删除', // extra key
-    extra: '额外的' // another extra key
-  },
-  newSection: {
-    someKey: '某个值' // extra section
-  }
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent)
-      fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent)
-
-      const enKeys = await getKeysFromLanguage('en-US')
-      const zhKeys = await getKeysFromLanguage('zh-Hans')
-
-      const extraKeys = zhKeys.filter(key => !enKeys.includes(key))
-
-      expect(extraKeys).toContain('test.common.delete')
-      expect(extraKeys).toContain('test.common.extra')
-      expect(extraKeys).toContain('test.newSection.someKey')
-      expect(extraKeys).toHaveLength(3)
-    })
-  })
-
-  describe('File filtering logic', () => {
-    it('should filter keys by specific file correctly', async () => {
-      // Create multiple files
-      const file1Content = `const translation = {
-  button: 'Button',
-  text: 'Text'
-}
-
-export default translation
-`
-
-      const file2Content = `const translation = {
-  title: 'Title',
-  description: 'Description'
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'components.ts'), file1Content)
-      fs.writeFileSync(path.join(testEnDir, 'pages.ts'), file2Content)
-      fs.writeFileSync(path.join(testZhDir, 'components.ts'), file1Content)
-      fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content)
-
-      const allEnKeys = await getKeysFromLanguage('en-US')
-
-      // Test file filtering logic
-      const targetFile = 'components'
-      const filteredEnKeys = allEnKeys.filter(key =>
-        key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())),
-      )
-
-      expect(allEnKeys).toHaveLength(4) // 2 keys from each file
-      expect(filteredEnKeys).toHaveLength(2) // only components keys
-      expect(filteredEnKeys).toContain('components.button')
-      expect(filteredEnKeys).toContain('components.text')
-      expect(filteredEnKeys).not.toContain('pages.title')
-      expect(filteredEnKeys).not.toContain('pages.description')
-    })
-  })
-
-  describe('Complex nested structure handling', () => {
-    it('should handle deeply nested objects correctly', async () => {
-      const complexContent = `const translation = {
-  level1: {
-    level2: {
-      level3: {
-        level4: {
-          deepValue: 'Deep Value'
-        },
-        anotherValue: 'Another Value'
-      },
-      simpleValue: 'Simple Value'
-    },
-    directValue: 'Direct Value'
-  },
-  rootValue: 'Root Value'
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'complex.ts'), complexContent)
-
-      const keys = await getKeysFromLanguage('en-US')
-
-      expect(keys).toContain('complex.level1.level2.level3.level4.deepValue')
-      expect(keys).toContain('complex.level1.level2.level3.anotherValue')
-      expect(keys).toContain('complex.level1.level2.simpleValue')
-      expect(keys).toContain('complex.level1.directValue')
-      expect(keys).toContain('complex.rootValue')
-
-      // Should not include intermediate objects
-      expect(keys).not.toContain('complex.level1')
-      expect(keys).not.toContain('complex.level1.level2')
-      expect(keys).not.toContain('complex.level1.level2.level3')
-      expect(keys).not.toContain('complex.level1.level2.level3.level4')
-    })
-  })
-
-  describe('Edge cases', () => {
-    it('should handle empty objects', async () => {
-      const emptyContent = `const translation = {
-  empty: {},
-  withValue: 'value'
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'empty.ts'), emptyContent)
-
-      const keys = await getKeysFromLanguage('en-US')
-
-      expect(keys).toContain('empty.withValue')
-      expect(keys).not.toContain('empty.empty')
-    })
-
-    it('should handle special characters in keys', async () => {
-      const specialContent = `const translation = {
-  'key-with-dash': 'value1',
-  'key_with_underscore': 'value2',
-  'key.with.dots': 'value3',
-  normalKey: 'value4'
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'special.ts'), specialContent)
-
-      const keys = await getKeysFromLanguage('en-US')
-
-      expect(keys).toContain('special.key-with-dash')
-      expect(keys).toContain('special.key_with_underscore')
-      expect(keys).toContain('special.key.with.dots')
-      expect(keys).toContain('special.normalKey')
-    })
-
-    it('should handle different value types', async () => {
-      const typesContent = `const translation = {
-  stringValue: 'string',
-  numberValue: 42,
-  booleanValue: true,
-  nullValue: null,
-  undefinedValue: undefined,
-  arrayValue: ['array', 'values'],
-  objectValue: {
-    nested: 'nested value'
-  }
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'types.ts'), typesContent)
-
-      const keys = await getKeysFromLanguage('en-US')
-
-      expect(keys).toContain('types.stringValue')
-      expect(keys).toContain('types.numberValue')
-      expect(keys).toContain('types.booleanValue')
-      expect(keys).toContain('types.nullValue')
-      expect(keys).toContain('types.undefinedValue')
-      expect(keys).toContain('types.arrayValue')
-      expect(keys).toContain('types.objectValue.nested')
-      expect(keys).not.toContain('types.objectValue')
-    })
-  })
-
-  describe('Real-world scenario tests', () => {
-    it('should handle app-debug structure like real files', async () => {
-      const appDebugEn = `const translation = {
-  pageTitle: {
-    line1: 'Prompt',
-    line2: 'Engineering'
-  },
-  operation: {
-    applyConfig: 'Publish',
-    resetConfig: 'Reset',
-    debugConfig: 'Debug'
-  },
-  generate: {
-    instruction: 'Instructions',
-    generate: 'Generate',
-    resTitle: 'Generated Prompt',
-    noDataLine1: 'Describe your use case on the left,',
-    noDataLine2: 'the orchestration preview will show here.'
-  }
-}
-
-export default translation
-`
-
-      const appDebugZh = `const translation = {
-  pageTitle: {
-    line1: '提示词',
-    line2: '编排'
-  },
-  operation: {
-    applyConfig: '发布',
-    resetConfig: '重置',
-    debugConfig: '调试'
-  },
-  generate: {
-    instruction: '指令',
-    generate: '生成',
-    resTitle: '生成的提示词',
-    noData: '在左侧描述您的用例,编排预览将在此处显示。' // This is extra
-  }
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), appDebugEn)
-      fs.writeFileSync(path.join(testZhDir, 'app-debug.ts'), appDebugZh)
-
-      const enKeys = await getKeysFromLanguage('en-US')
-      const zhKeys = await getKeysFromLanguage('zh-Hans')
-
-      const missingKeys = enKeys.filter(key => !zhKeys.includes(key))
-      const extraKeys = zhKeys.filter(key => !enKeys.includes(key))
-
-      expect(missingKeys).toContain('appDebug.generate.noDataLine1')
-      expect(missingKeys).toContain('appDebug.generate.noDataLine2')
-      expect(extraKeys).toContain('appDebug.generate.noData')
-
-      expect(missingKeys).toHaveLength(2)
-      expect(extraKeys).toHaveLength(1)
-    })
-
-    it('should handle time structure with operation nested keys', async () => {
-      const timeEn = `const translation = {
-  months: {
-    January: 'January',
-    February: 'February'
-  },
-  operation: {
-    now: 'Now',
-    ok: 'OK',
-    cancel: 'Cancel',
-    pickDate: 'Pick Date'
-  },
-  title: {
-    pickTime: 'Pick Time'
-  },
-  defaultPlaceholder: 'Pick a time...'
-}
-
-export default translation
-`
-
-      const timeZh = `const translation = {
-  months: {
-    January: '一月',
-    February: '二月'
-  },
-  operation: {
-    now: '此刻',
-    ok: '确定',
-    cancel: '取消',
-    pickDate: '选择日期'
-  },
-  title: {
-    pickTime: '选择时间'
-  },
-  pickDate: '选择日期', // This is extra - duplicates operation.pickDate
-  defaultPlaceholder: '请选择时间...'
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'time.ts'), timeEn)
-      fs.writeFileSync(path.join(testZhDir, 'time.ts'), timeZh)
-
-      const enKeys = await getKeysFromLanguage('en-US')
-      const zhKeys = await getKeysFromLanguage('zh-Hans')
-
-      const missingKeys = enKeys.filter(key => !zhKeys.includes(key))
-      const extraKeys = zhKeys.filter(key => !enKeys.includes(key))
-
-      expect(missingKeys).toHaveLength(0) // No missing keys
-      expect(extraKeys).toContain('time.pickDate') // Extra root-level pickDate
-      expect(extraKeys).toHaveLength(1)
-
-      // Should have both keys available
-      expect(zhKeys).toContain('time.operation.pickDate') // Correct nested key
-      expect(zhKeys).toContain('time.pickDate') // Extra duplicate key
-    })
-  })
-
-  describe('Statistics calculation', () => {
-    it('should calculate correct difference statistics', async () => {
-      const enContent = `const translation = {
-  key1: 'value1',
-  key2: 'value2',
-  key3: 'value3'
-}
-
-export default translation
-`
-
-      const zhContentMissing = `const translation = {
-  key1: 'value1',
-  key2: 'value2'
-  // missing key3
-}
-
-export default translation
-`
-
-      const zhContentExtra = `const translation = {
-  key1: 'value1',
-  key2: 'value2',
-  key3: 'value3',
-  key4: 'extra',
-  key5: 'extra2'
-}
-
-export default translation
-`
-
-      fs.writeFileSync(path.join(testEnDir, 'stats.ts'), enContent)
-
-      // Test missing keys scenario
-      fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentMissing)
-
-      const enKeys = await getKeysFromLanguage('en-US')
-      const zhKeysMissing = await getKeysFromLanguage('zh-Hans')
-
-      expect(enKeys.length - zhKeysMissing.length).toBe(1) // +1 means 1 missing key
-
-      // Test extra keys scenario
-      fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentExtra)
-
-      const zhKeysExtra = await getKeysFromLanguage('zh-Hans')
-
-      expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys
-    })
-  })
-
-  describe('Auto-remove multiline key-value pairs', () => {
-    // Helper function to simulate removeExtraKeysFromFile logic
-    function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string {
-      const lines = content.split('\n')
-      const linesToRemove: number[] = []
-
-      for (const keyToRemove of keysToRemove) {
-        let targetLineIndex = -1
-        const linesToRemoveForKey: number[] = []
-
-        // Find the key line (simplified for single-level keys in test)
-        for (let i = 0; i < lines.length; i++) {
-          const line = lines[i]
-          const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`)
-          if (keyPattern.test(line)) {
-            targetLineIndex = i
-            break
-          }
-        }
-
-        if (targetLineIndex !== -1) {
-          linesToRemoveForKey.push(targetLineIndex)
-
-          // Check if this is a multiline key-value pair
-          const keyLine = lines[targetLineIndex]
-          const trimmedKeyLine = keyLine.trim()
-
-          // If key line ends with ":" (not complete value), it's likely multiline
-          if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !/:\s*['"`]/.exec(trimmedKeyLine)) {
-            // Find the value lines that belong to this key
-            let currentLine = targetLineIndex + 1
-            let foundValue = false
-
-            while (currentLine < lines.length) {
-              const line = lines[currentLine]
-              const trimmed = line.trim()
-
-              // Skip empty lines
-              if (trimmed === '') {
-                currentLine++
-                continue
-              }
-
-              // Check if this line starts a new key (indicates end of current value)
-              if (/^\w+\s*:/.exec(trimmed))
-                break
-
-              // Check if this line is part of the value
-              if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) {
-                linesToRemoveForKey.push(currentLine)
-                foundValue = true
-
-                // Check if this line ends the value (ends with quote and comma/no comma)
-                if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,')
-                  || trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`'))
-                  && !trimmed.startsWith('//')) {
-                  break
-                }
-              }
-              else {
-                break
-              }
-
-              currentLine++
-            }
-          }
-
-          linesToRemove.push(...linesToRemoveForKey)
-        }
-      }
-
-      // Remove duplicates and sort in reverse order
-      const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a)
-
-      for (const lineIndex of uniqueLinesToRemove)
-        lines.splice(lineIndex, 1)
-
-      return lines.join('\n')
-    }
-
-    it('should remove single-line key-value pairs correctly', () => {
-      const content = `const translation = {
-  keepThis: 'This should stay',
-  removeThis: 'This should be removed',
-  alsoKeep: 'This should also stay',
-}
-
-export default translation`
-
-      const result = removeExtraKeysFromFile(content, ['removeThis'])
-
-      expect(result).toContain('keepThis: \'This should stay\'')
-      expect(result).toContain('alsoKeep: \'This should also stay\'')
-      expect(result).not.toContain('removeThis: \'This should be removed\'')
-    })
-
-    it('should remove multiline key-value pairs completely', () => {
-      const content = `const translation = {
-  keepThis: 'This should stay',
-  removeMultiline:
-    'This is a multiline value that should be removed completely',
-  alsoKeep: 'This should also stay',
-}
-
-export default translation`
-
-      const result = removeExtraKeysFromFile(content, ['removeMultiline'])
-
-      expect(result).toContain('keepThis: \'This should stay\'')
-      expect(result).toContain('alsoKeep: \'This should also stay\'')
-      expect(result).not.toContain('removeMultiline:')
-      expect(result).not.toContain('This is a multiline value that should be removed completely')
-    })
-
-    it('should handle mixed single-line and multiline removals', () => {
-      const content = `const translation = {
-  keepThis: 'Keep this',
-  removeSingle: 'Remove this single line',
-  removeMultiline:
-    'Remove this multiline value',
-  anotherMultiline:
-    'Another multiline that spans multiple lines',
-  keepAnother: 'Keep this too',
-}
-
-export default translation`
-
-      const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline'])
-
-      expect(result).toContain('keepThis: \'Keep this\'')
-      expect(result).toContain('keepAnother: \'Keep this too\'')
-      expect(result).not.toContain('removeSingle:')
-      expect(result).not.toContain('removeMultiline:')
-      expect(result).not.toContain('anotherMultiline:')
-      expect(result).not.toContain('Remove this single line')
-      expect(result).not.toContain('Remove this multiline value')
-      expect(result).not.toContain('Another multiline that spans multiple lines')
-    })
-
-    it('should properly detect multiline vs single-line patterns', () => {
-      const multilineContent = `const translation = {
-  singleLine: 'This is single line',
-  multilineKey:
-    'This is multiline',
-  keyWithColon: 'Value with: colon inside',
-  objectKey: {
-    nested: 'value'
-  },
-}
-
-export default translation`
-
-      // Test that single line with colon in value is not treated as multiline
-      const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon'])
-      expect(result1).not.toContain('keyWithColon:')
-      expect(result1).not.toContain('Value with: colon inside')
-
-      // Test that true multiline is handled correctly
-      const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey'])
-      expect(result2).not.toContain('multilineKey:')
-      expect(result2).not.toContain('This is multiline')
-
-      // Test that object key removal works (note: this is a simplified test)
-      // In real scenario, object removal would be more complex
-      const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey'])
-      expect(result3).not.toContain('objectKey: {')
-      // Note: Our simplified test function doesn't handle nested object removal perfectly
-      // This is acceptable as it's testing the main multiline string removal functionality
-    })
-
-    it('should handle real-world Polish translation structure', () => {
-      const polishContent = `const translation = {
-  createApp: 'UTWÓRZ APLIKACJĘ',
-  newApp: {
-    captionAppType: 'Jaki typ aplikacji chcesz stworzyć?',
-    chatbotDescription:
-      'Zbuduj aplikację opartą na czacie. Ta aplikacja używa formatu pytań i odpowiedzi.',
-    agentDescription:
-      'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.',
-    basic: 'Podstawowy',
-  },
-}
-
-export default translation`
-
-      const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription'])
-
-      expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'')
-      expect(result).toContain('basic: \'Podstawowy\'')
-      expect(result).not.toContain('captionAppType:')
-      expect(result).not.toContain('chatbotDescription:')
-      expect(result).not.toContain('agentDescription:')
-      expect(result).not.toContain('Jaki typ aplikacji')
-      expect(result).not.toContain('Zbuduj aplikację opartą na czacie')
-      expect(result).not.toContain('Zbuduj inteligentnego agenta')
-    })
-  })
-
-  describe('Performance and Scalability', () => {
-    it('should handle large translation files efficiently', async () => {
-      // Create a large translation file with 1000 keys
-      const largeContent = `const translation = {
-${Array.from({ length: 1000 }, (_, i) => `  key${i}: 'value${i}',`).join('\n')}
-}
-
-export default translation`
-
-      fs.writeFileSync(path.join(testEnDir, 'large.ts'), largeContent)
-
-      const startTime = Date.now()
-      const keys = await getKeysFromLanguage('en-US')
-      const endTime = Date.now()
-
-      expect(keys.length).toBe(1000)
-      expect(endTime - startTime).toBeLessThan(10000)
-    })
-
-    it('should handle multiple translation files concurrently', async () => {
-      // Create multiple files
-      for (let i = 0; i < 10; i++) {
-        const content = `const translation = {
-  key${i}: 'value${i}',
-  nested${i}: {
-    subkey: 'subvalue'
-  }
-}
-
-export default translation`
-        fs.writeFileSync(path.join(testEnDir, `file${i}.ts`), content)
-      }
-
-      const startTime = Date.now()
-      const keys = await getKeysFromLanguage('en-US')
-      const endTime = Date.now()
-
-      expect(keys.length).toBe(20) // 10 files * 2 keys each
-      expect(endTime - startTime).toBeLessThan(10000)
-    })
-  })
-
-  describe('Unicode and Internationalization', () => {
-    it('should handle Unicode characters in keys and values', async () => {
-      const unicodeContent = `const translation = {
-  '中文键': '中文值',
-  'العربية': 'قيمة',
-  'emoji_😀': 'value with emoji 🎉',
-  'mixed_中文_English': 'mixed value'
-}
-
-export default translation`
-
-      fs.writeFileSync(path.join(testEnDir, 'unicode.ts'), unicodeContent)
-
-      const keys = await getKeysFromLanguage('en-US')
-
-      expect(keys).toContain('unicode.中文键')
-      expect(keys).toContain('unicode.العربية')
-      expect(keys).toContain('unicode.emoji_😀')
-      expect(keys).toContain('unicode.mixed_中文_English')
-    })
-
-    it('should handle RTL language files', async () => {
-      const rtlContent = `const translation = {
-  مرحبا: 'Hello',
-  العالم: 'World',
-  nested: {
-    مفتاح: 'key'
-  }
-}
-
-export default translation`
-
-      fs.writeFileSync(path.join(testEnDir, 'rtl.ts'), rtlContent)
-
-      const keys = await getKeysFromLanguage('en-US')
-
-      expect(keys).toContain('rtl.مرحبا')
-      expect(keys).toContain('rtl.العالم')
-      expect(keys).toContain('rtl.nested.مفتاح')
-    })
-  })
-
-  describe('Error Recovery', () => {
-    it('should handle syntax errors in translation files gracefully', async () => {
-      const invalidContent = `const translation = {
-  validKey: 'valid value',
-  invalidKey: 'missing quote,
-  anotherKey: 'another value'
-}
-
-export default translation`
-
-      fs.writeFileSync(path.join(testEnDir, 'invalid.ts'), invalidContent)
-
-      await expect(getKeysFromLanguage('en-US')).rejects.toThrow()
-    })
-  })
-})
diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx
index 3950bdf7ee..3a5f2272ed 100644
--- a/web/app/components/app/app-access-control/access-control.spec.tsx
+++ b/web/app/components/app/app-access-control/access-control.spec.tsx
@@ -109,7 +109,7 @@ beforeAll(() => {
     disconnect = vi.fn(() => undefined)
     unobserve = vi.fn(() => undefined)
   }
-  // @ts-expect-error jsdom does not implement IntersectionObserver
+  // @ts-expect-error test DOM typings do not guarantee IntersectionObserver here
   globalThis.IntersectionObserver = MockIntersectionObserver
 })
diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx
index 188086246a..389ab189e9 100644
--- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx
+++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx
@@ -556,8 +556,8 @@ describe('DebugWithMultipleModel', () => {
     )
     const twoItems = screen.getAllByTestId('debug-item')
-    expect(twoItems[0].style.width).toBe('calc(50% - 28px)')
-    expect(twoItems[1].style.width).toBe('calc(50% - 28px)')
+    expect(twoItems[0].style.width).toBe('calc(50% - 4px - 24px)')
+    expect(twoItems[1].style.width).toBe('calc(50% - 4px - 24px)')
   })
 })
@@ -596,13 +596,13 @@ describe('DebugWithMultipleModel', () => {
     // Assert
     expect(items).toHaveLength(2)
     expectItemLayout(items[0], {
-      width: 'calc(50% - 28px)',
+      width: 'calc(50% - 4px - 24px)',
       height: '100%',
       transform: 'translateX(0) translateY(0)',
       classes: ['mr-2'],
     })
     expectItemLayout(items[1], {
-      width: 'calc(50% - 28px)',
+      width: 'calc(50% - 4px - 24px)',
       height: '100%',
       transform: 'translateX(calc(100% + 8px)) translateY(0)',
       classes: [],
@@ -620,19 +620,19 @@ describe('DebugWithMultipleModel', () => {
     // Assert
     expect(items).toHaveLength(3)
     expectItemLayout(items[0], {
-      width: 'calc(33.3% - 21.33px)',
+      width: 'calc(33.3% - 5.33px - 16px)',
       height: '100%',
       transform: 'translateX(0) translateY(0)',
       classes: ['mr-2'],
     })
     expectItemLayout(items[1], {
-      width: 'calc(33.3% - 21.33px)',
+      width: 'calc(33.3% - 5.33px - 16px)',
       height: '100%',
       transform: 'translateX(calc(100% + 8px)) translateY(0)',
       classes: ['mr-2'],
     })
     expectItemLayout(items[2], {
-      width: 'calc(33.3% - 21.33px)',
+      width: 'calc(33.3% - 5.33px - 16px)',
       height: '100%',
       transform: 'translateX(calc(200% + 16px)) translateY(0)',
       classes: [],
@@ -655,25 +655,25 @@ describe('DebugWithMultipleModel', () => {
     // Assert
     expect(items).toHaveLength(4)
     expectItemLayout(items[0], {
-      width: 'calc(50% - 28px)',
+      width: 'calc(50% - 4px - 24px)',
       height: 'calc(50% - 4px)',
       transform: 'translateX(0) translateY(0)',
       classes: ['mr-2', 'mb-2'],
     })
     expectItemLayout(items[1], {
-      width: 'calc(50% - 28px)',
+      width: 'calc(50% - 4px - 24px)',
       height: 'calc(50% - 4px)',
       transform: 'translateX(calc(100% + 8px)) translateY(0)',
       classes: ['mb-2'],
     })
     expectItemLayout(items[2], {
-      width: 'calc(50% - 28px)',
+      width: 'calc(50% - 4px - 24px)',
       height: 'calc(50% - 4px)',
       transform: 'translateX(0) translateY(calc(100% + 8px))',
       classes: ['mr-2'],
     })
     expectItemLayout(items[3], {
-      width: 'calc(50% - 28px)',
+      width: 'calc(50% - 4px - 24px)',
       height: 'calc(50% - 4px)',
       transform: 'translateX(calc(100% + 8px)) translateY(calc(100% + 8px))',
       classes: [],
diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx
index b849b4f015..e933855ca8 100644
--- a/web/app/components/app/overview/settings/index.spec.tsx
+++ b/web/app/components/app/overview/settings/index.spec.tsx
@@ -1,6 +1,3 @@
-/**
- * @vitest-environment jsdom
- */
 import type { ReactNode } from 'react'
 import type { ModalContextState } from '@/context/modal-context'
 import type { ProviderContextState } from '@/context/provider-context'
diff --git a/web/app/components/base/action-button/__tests__/index.spec.tsx b/web/app/components/base/action-button/__tests__/index.spec.tsx
index 949a980272..e9db157d0c 100644
--- a/web/app/components/base/action-button/__tests__/index.spec.tsx
+++ b/web/app/components/base/action-button/__tests__/index.spec.tsx
@@ -62,8 +62,8 @@ describe('ActionButton', () => {
     )
     const button = screen.getByRole('button', { name: 'Custom Style' })
     expect(button).toHaveStyle({
-      color: 'rgb(255, 0, 0)',
-      backgroundColor: 'rgb(0, 0, 255)',
+      color: 'red',
+      backgroundColor: 'blue',
     })
   })
diff --git a/web/app/components/base/chat/chat/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/__tests__/index.spec.tsx
index 781b5e86f3..0100b059f0 100644
--- a/web/app/components/base/chat/chat/__tests__/index.spec.tsx
+++ b/web/app/components/base/chat/chat/__tests__/index.spec.tsx
@@ -8,10 +8,10 @@ import Chat from '../index'
 // ─── Why each mock exists ─────────────────────────────────────────────────────
 //
 // Answer – transitively pulls Markdown (rehype/remark/katex), AgentContent,
-//                 WorkflowProcessItem and Operation; none can resolve in jsdom.
+//                 WorkflowProcessItem and Operation; none can resolve in the test DOM runtime.
 // Question – pulls Markdown, copy-to-clipboard, react-textarea-autosize.
 // ChatInputArea – pulls js-audio-recorder (requires Web Audio API unavailable in
-//                 jsdom) and VoiceInput / FileContextProvider chains.
+//                 the test DOM runtime) and VoiceInput / FileContextProvider chains.
 // PromptLogModal– pulls CopyFeedbackNew and deep modal dep chain.
 // AgentLogModal – pulls @remixicon/react (causes lint push error), useClickAway
 //                 from ahooks, and AgentLogDetail (workflow graph renderer).
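The mock-rationale block above documents which transitive dependency forces each `vi.mock`. A minimal sketch of that stubbing pattern, assuming Vitest with React Testing Library; the module path and stub markup below are illustrative stand-ins, not the actual Chat dependencies:

```tsx
import { vi } from 'vitest'

// Stub a heavy transitive dependency (for example a Markdown renderer that
// drags in rehype/remark/katex) with a cheap passthrough so the suite stays
// resolvable in the test DOM runtime. The path is hypothetical.
vi.mock('@/app/components/base/markdown', () => ({
  __esModule: true,
  default: ({ content }: { content: string }) => (
    <pre data-testid="markdown-stub">{content}</pre>
  ),
}))
```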
diff --git a/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx b/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx
index d8e00780b1..8839798c15 100644
--- a/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx
+++ b/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx
@@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react'
 import dayjs from '../../utils/dayjs'
 import Calendar from '../index'

-// Mock scrollIntoView since jsdom doesn't implement it
+// Mock scrollIntoView since the test DOM runtime doesn't implement it
 beforeAll(() => {
   Element.prototype.scrollIntoView = vi.fn()
 })
diff --git a/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx b/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx
index 910faf9cd4..199ed4ee41 100644
--- a/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx
+++ b/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx
@@ -3,7 +3,7 @@ import { fireEvent, render, screen, within } from '@testing-library/react'
 import dayjs, { isDayjsObject } from '../../utils/dayjs'
 import TimePicker from '../index'

-// Mock scrollIntoView since jsdom doesn't implement it
+// Mock scrollIntoView since the test DOM runtime doesn't implement it
 beforeAll(() => {
   Element.prototype.scrollIntoView = vi.fn()
 })
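Both picker suites patch `Element.prototype.scrollIntoView` in their own `beforeAll`. The testing guide touched later in this diff suggests environment-level mocks can live once in `vitest.setup.ts`; a sketch of that centralized form, offered as an assumption about where such a stub could go rather than something this diff does:

```ts
// vitest.setup.ts (sketch): stub a DOM API the runtime lacks, once for all specs.
import { vi } from 'vitest'

if (typeof Element.prototype.scrollIntoView !== 'function')
  Element.prototype.scrollIntoView = vi.fn()
```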
diff --git a/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx b/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx
index 69496903a6..de9cc7ecd0 100644
--- a/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx
+++ b/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx
@@ -37,11 +37,11 @@ const FileFromLinkOrLocal = ({
   const { handleLoadFileFromLink } = useFile(fileConfig)
   const disabled = !!fileConfig.number_limits && files.length >= fileConfig.number_limits
   const fileLinkPlaceholder = t('fileUploader.pasteFileLinkInputPlaceholder', { ns: 'common' })
-  /* v8 ignore next -- fallback for missing i18n key is not reliably testable under current global translation mocks in jsdom @preserve */
+  /* v8 ignore next -- fallback for a missing i18n key is not reliably testable under the current global translation mocks in the test DOM runtime. @preserve */
   const fileLinkPlaceholderText = fileLinkPlaceholder || ''

   const handleSaveUrl = () => {
-    /* v8 ignore next -- guarded by UI-level disabled state (`disabled={!url || disabled}`), not reachable in jsdom click flow @preserve */
+    /* v8 ignore next -- guarded by UI-level disabled state (`disabled={!url || disabled}`), not reachable in the current test click flow. @preserve */
     if (!url)
       return
diff --git a/web/app/components/base/icons/__tests__/utils.spec.ts b/web/app/components/base/icons/__tests__/utils.spec.ts
index a25f39111d..f8534038bf 100644
--- a/web/app/components/base/icons/__tests__/utils.spec.ts
+++ b/web/app/components/base/icons/__tests__/utils.spec.ts
@@ -62,7 +62,7 @@ describe('generate icon base utils', () => {
     const { container } = render(generate(node, 'key')) // to svg element

     expect(container.firstChild).toHaveClass('container')
-    expect(container.querySelector('span')).toHaveStyle({ color: 'rgb(0, 0, 255)' })
+    expect(container.querySelector('span')).toHaveStyle({ color: 'blue' })
   })

   // add not has children
diff --git a/web/app/components/base/input/__tests__/index.spec.tsx b/web/app/components/base/input/__tests__/index.spec.tsx
index 2c5b563a12..dfab8617c2 100644
--- a/web/app/components/base/input/__tests__/index.spec.tsx
+++ b/web/app/components/base/input/__tests__/index.spec.tsx
@@ -99,7 +99,7 @@ describe('Input component', () => {
     render()
     const input = screen.getByPlaceholderText(/input/i)
     expect(input).toHaveClass(customClass)
-    expect(input).toHaveStyle({ color: 'rgb(255, 0, 0)' })
+    expect(input).toHaveStyle({ color: 'red' })
   })

   it('applies large size variant correctly', () => {
diff --git a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx
index 745b7657d7..a16686801c 100644
--- a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx
+++ b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx
@@ -1,6 +1,6 @@
-import { createRequire } from 'node:module'
 import { act, render, screen, waitFor } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
+import * as echarts from 'echarts'
 import { Theme } from '@/types/app'
 import CodeBlock from '../code-block'

@@ -10,12 +10,21 @@ type UseThemeReturn = {
 }

 const mockUseTheme = vi.fn<() => UseThemeReturn>(() => ({ theme: Theme.light }))
-const require = createRequire(import.meta.url)
-const echartsCjs = require('echarts') as {
-  getInstanceByDom: (dom: HTMLDivElement | null) => {
-    resize: (opts?: { width?: string, height?: string }) => void
-  } | null
-}
+const mockEcharts = vi.hoisted(() => {
+  const state = {
+    finishedHandler: undefined as undefined | ((event?: unknown) => void),
+    echartsInstance: {
+      resize: vi.fn<(opts?: { width?: string, height?: string }) => void>(),
+      trigger: vi.fn((eventName: string, event?: unknown) => {
+        if (eventName === 'finished')
+          state.finishedHandler?.(event)
+      }),
+    },
+    getInstanceByDom: vi.fn(() => state.echartsInstance),
+  }
+
+  return state
+})

 let clientWidthSpy: { mockRestore: () => void } | null = null
 let clientHeightSpy: { mockRestore: () => void } | null = null
@@ -61,6 +70,42 @@ vi.mock('@/hooks/use-theme', () => ({
   default: () => mockUseTheme(),
 }))

+vi.mock('echarts', () => ({
+  getInstanceByDom: mockEcharts.getInstanceByDom,
+}))
+
+vi.mock('echarts-for-react', async () => {
+  const React = await vi.importActual('react')
+
+  const MockReactEcharts = React.forwardRef(({
+    onChartReady,
+    onEvents,
+  }: {
+    onChartReady?: (instance: typeof mockEcharts.echartsInstance) => void
+    onEvents?: { finished?: (event?: unknown) => void }
+  }, ref: React.ForwardedRef<{ getEchartsInstance: () => typeof mockEcharts.echartsInstance }>) => {
+    React.useImperativeHandle(ref, () => ({
+      getEchartsInstance: () => mockEcharts.echartsInstance,
+    }))
+
+    React.useEffect(() => {
+      mockEcharts.finishedHandler = onEvents?.finished
+      onChartReady?.(mockEcharts.echartsInstance)
+      onEvents?.finished?.({})
+      return () => {
+        mockEcharts.finishedHandler = undefined
+      }
+    }, [onChartReady, onEvents])
+
+    return
+  })
+
+  return {
+    __esModule: true,
+    default: MockReactEcharts,
+  }
+})
+
 vi.mock('@/app/components/base/mermaid', () => ({
   __esModule: true,
   default: ({ PrimitiveCode }: { PrimitiveCode: string }) => {PrimitiveCode},
@@ -76,9 +121,9 @@ const findEchartsHost = async () => {
 const findEchartsInstance = async () => {
   const host = await findEchartsHost()
   await waitFor(() => {
-    expect(echartsCjs.getInstanceByDom(host)).toBeTruthy()
+    expect(echarts.getInstanceByDom(host)).toBeTruthy()
   })
-  return echartsCjs.getInstanceByDom(host)!
+  return echarts.getInstanceByDom(host)!
 }

 describe('CodeBlock', () => {
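The hunks above replace a `createRequire` CJS load of echarts with `vi.mock` backed by `vi.hoisted` state. The reason for `vi.hoisted`: Vitest hoists `vi.mock` factories above every import, so an ordinary top-level variable does not exist yet when the factory runs. A reduced sketch of the pattern with a hypothetical module name:

```ts
import { vi } from 'vitest'

// vi.hoisted runs in the same hoisted phase as vi.mock factories, so the
// factory below can safely close over `shared`. 'some-lib' and fetchValue
// are stand-ins, not real dependencies of this repo.
const shared = vi.hoisted(() => ({
  fetchValue: vi.fn(() => 42),
}))

vi.mock('some-lib', () => ({
  fetchValue: shared.fetchValue,
}))
```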
diff --git a/web/app/components/base/node-status/__tests__/index.spec.tsx b/web/app/components/base/node-status/__tests__/index.spec.tsx
index f74af4965e..37b12946c8 100644
--- a/web/app/components/base/node-status/__tests__/index.spec.tsx
+++ b/web/app/components/base/node-status/__tests__/index.spec.tsx
@@ -41,7 +41,7 @@ describe('NodeStatus', () => {
   it('applies styleCss correctly', () => {
     const { container } = render()

-    expect(container.firstChild).toHaveStyle({ color: 'rgb(255, 0, 0)' })
+    expect(container.firstChild).toHaveStyle({ color: 'red' })
   })

   it('applies iconClassName to the icon', () => {
diff --git a/web/app/components/base/pagination/__tests__/pagination.spec.tsx b/web/app/components/base/pagination/__tests__/pagination.spec.tsx
index 776802ff19..06eac9bfbd 100644
--- a/web/app/components/base/pagination/__tests__/pagination.spec.tsx
+++ b/web/app/components/base/pagination/__tests__/pagination.spec.tsx
@@ -131,7 +131,7 @@ describe('Pagination', () => {
       setCurrentPage,
       children: Prev,
     })
-    fireEvent.keyPress(screen.getByText(/prev/i), { key: 'Enter', charCode: 13 })
+    fireEvent.keyDown(screen.getByText(/prev/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 })
     expect(setCurrentPage).toHaveBeenCalledWith(2)
   })

@@ -142,7 +142,7 @@ describe('Pagination', () => {
       setCurrentPage,
       children: Prev,
     })
-    fireEvent.keyPress(screen.getByText(/prev/i), { key: 'Enter', charCode: 13 })
+    fireEvent.keyDown(screen.getByText(/prev/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 })
    expect(setCurrentPage).not.toHaveBeenCalled()
   })

@@ -213,7 +213,7 @@ describe('Pagination', () => {
       setCurrentPage,
       children: Next,
     })
-    fireEvent.keyPress(screen.getByText(/next/i), { key: 'Enter', charCode: 13 })
+    fireEvent.keyDown(screen.getByText(/next/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 })
     expect(setCurrentPage).toHaveBeenCalledWith(1)
   })

@@ -225,7 +225,7 @@ describe('Pagination', () => {
       setCurrentPage,
       children: Next,
     })
-    fireEvent.keyPress(screen.getByText(/next/i), { key: 'Enter', charCode: 13 })
+    fireEvent.keyDown(screen.getByText(/next/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 })
     expect(setCurrentPage).not.toHaveBeenCalled()
   })

@@ -318,7 +318,7 @@ describe('Pagination', () => {
       />
       ),
     })
-    fireEvent.keyPress(screen.getByText('4'), { key: 'Enter', charCode: 13 })
+    fireEvent.keyDown(screen.getByText('4').closest('a')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 })
     expect(setCurrentPage).toHaveBeenCalledWith(3) // 0-indexed
   })
diff --git a/web/app/components/base/pagination/pagination.tsx b/web/app/components/base/pagination/pagination.tsx
index 0eb06b594c..b258090d80 100644
--- a/web/app/components/base/pagination/pagination.tsx
+++ b/web/app/components/base/pagination/pagination.tsx
@@ -50,7 +50,7 @@ export const PrevButton = ({
   tabIndex={disabled ? '-1' : 0}
   disabled={disabled}
   data-testid={dataTestId}
-  onKeyPress={(event: React.KeyboardEvent) => {
+  onKeyDown={(event: React.KeyboardEvent) => {
     event.preventDefault()
     if (event.key === 'Enter' && !disabled)
       previous()
@@ -85,7 +85,7 @@ export const NextButton = ({
   tabIndex={disabled ? '-1' : 0}
   disabled={disabled}
   data-testid={dataTestId}
-  onKeyPress={(event: React.KeyboardEvent) => {
+  onKeyDown={(event: React.KeyboardEvent) => {
     event.preventDefault()
     if (event.key === 'Enter' && !disabled)
       next()
@@ -140,7 +140,7 @@ export const PageButton = ({
     }) || undefined
   }
   tabIndex={0}
-  onKeyPress={(event: React.KeyboardEvent) => {
+  onKeyDown={(event: React.KeyboardEvent) => {
     if (event.key === 'Enter')
       pagination.setCurrentPage(page - 1)
   }}
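The `keypress` event is deprecated and is not dispatched by every DOM test runtime, which is why the component now listens on `onKeyDown` and the specs fire `keyDown` at the focusable element itself. A self-contained sketch of the pattern, assuming Vitest and Testing Library; the inline button is a stand-in, not the Pagination component:

```tsx
import { fireEvent, render, screen } from '@testing-library/react'
import { expect, it, vi } from 'vitest'

it('activates on Enter via keydown', () => {
  const onActivate = vi.fn()
  render(
    <button onKeyDown={e => e.key === 'Enter' && onActivate()}>Go</button>,
  )
  // Target the focusable element directly and use keydown, which fires
  // reliably even where keypress is never emitted.
  fireEvent.keyDown(screen.getByRole('button', { name: 'Go' }), { key: 'Enter', code: 'Enter' })
  expect(onActivate).toHaveBeenCalledTimes(1)
})
```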
diff --git a/web/app/components/base/premium-badge/__tests__/index.spec.tsx b/web/app/components/base/premium-badge/__tests__/index.spec.tsx
index af8ace22f0..d107c07e52 100644
--- a/web/app/components/base/premium-badge/__tests__/index.spec.tsx
+++ b/web/app/components/base/premium-badge/__tests__/index.spec.tsx
@@ -41,6 +41,6 @@ describe('PremiumBadge', () => {
     )
     const badge = screen.getByText('Premium')
     expect(badge).toBeInTheDocument()
-    expect(badge).toHaveStyle('background-color: rgb(255, 0, 0)') // Note: React converts 'red' to 'rgb(255, 0, 0)'
+    expect(badge).toHaveStyle('background-color: red')
   })
 })
diff --git a/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx b/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx
index abe6ea9a45..7dcda803f2 100644
--- a/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx
+++ b/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx
@@ -141,7 +141,7 @@ export default function ShortcutsPopupPlugin({
   const portalRef = useRef(null)
   const lastSelectionRef = useRef(null)

-  /* v8 ignore next -- defensive non-browser fallback; this client-only plugin runs where document exists (browser/jsdom). @preserve */
+  /* v8 ignore next -- defensive non-browser fallback; this client-only plugin runs where document exists (browser/test DOM runtime). @preserve */
   const containerEl = useMemo(() => container ?? (typeof document !== 'undefined' ? document.body : null), [container])
   const useContainer = !!containerEl && containerEl !== document.body

@@ -210,7 +210,7 @@ export default function ShortcutsPopupPlugin({
     if (rect.width === 0 && rect.height === 0) {
       const root = editor.getRootElement()
-      /* v8 ignore next 10 -- zero-size rect recovery depends on browser layout/selection geometry; deterministic reproduction in jsdom is unreliable. @preserve */
+      /* v8 ignore next 10 -- zero-size rect recovery depends on browser layout/selection geometry; deterministic reproduction in the test DOM runtime is unreliable. @preserve */
       if (root) {
         const sc = range.startContainer
         const node = sc.nodeType === Node.ELEMENT_NODE
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
index 7f292c8ff9..0a3470420c 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
@@ -1612,9 +1612,7 @@ describe('Uploader', () => {
     if (!dropArea)
       return

-    fireEvent.drop(dropArea, {
-      dataTransfer: null,
-    })
+    fireEvent.drop(dropArea)

     expect(updateFile).not.toHaveBeenCalled()
   })
diff --git a/web/app/components/datasets/create/website/watercrawl/__tests__/index.spec.tsx b/web/app/components/datasets/create/website/watercrawl/__tests__/index.spec.tsx
index 5ff2d8efb8..e9d933cc03 100644
--- a/web/app/components/datasets/create/website/watercrawl/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/create/website/watercrawl/__tests__/index.spec.tsx
@@ -1,6 +1,3 @@
-/**
- * @vitest-environment jsdom
- */
 import type { Mock } from 'vitest'
 import type { CrawlOptions, CrawlResultItem } from '@/models/datasets'
 import { fireEvent, render, screen, waitFor } from '@testing-library/react'
diff --git a/web/app/components/develop/__tests__/code.spec.tsx b/web/app/components/develop/__tests__/code.spec.tsx
index 452e6ea98f..e5eaebb600 100644
--- a/web/app/components/develop/__tests__/code.spec.tsx
+++ b/web/app/components/develop/__tests__/code.spec.tsx
@@ -11,7 +11,7 @@ describe('code.tsx components', () => {
     vi.clearAllMocks()
     vi.spyOn(console, 'error').mockImplementation(() => {})
     vi.useFakeTimers({ shouldAdvanceTime: true })
-    // jsdom does not implement scrollBy; mock it to prevent stderr noise
+    // The test DOM runtime does not implement scrollBy; mock it to prevent stderr noise
     window.scrollBy = vi.fn()
   })
diff --git a/web/app/components/develop/__tests__/use-doc-toc.spec.ts b/web/app/components/develop/__tests__/use-doc-toc.spec.ts
index e437e13065..b20c2c8ecf 100644
--- a/web/app/components/develop/__tests__/use-doc-toc.spec.ts
+++ b/web/app/components/develop/__tests__/use-doc-toc.spec.ts
@@ -307,7 +307,7 @@ describe('useDocToc', () => {
   it('should update activeSection when scrolling past a section', async () => {
     vi.useFakeTimers()

-    // innerHeight/2 = 384 in jsdom (default 768), so top <= 384 means "scrolled past"
+    // innerHeight/2 = 384 with the default test viewport height (768), so top <= 384 means "scrolled past"
     const { scrollContainer, cleanup } = setupScrollDOM([
       { id: 'intro', text: 'Intro', top: 100 },
       { id: 'details', text: 'Details', top: 600 },
diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/__tests__/add-custom-model.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/__tests__/add-custom-model.spec.tsx
index 6117420afa..43a27dac9b 100644
--- a/web/app/components/header/account-setting/model-provider-page/model-auth/__tests__/add-custom-model.spec.tsx
+++ b/web/app/components/header/account-setting/model-provider-page/model-auth/__tests__/add-custom-model.spec.tsx
@@ -43,7 +43,7 @@ vi.mock('@/app/components/base/tooltip', () => ({
   ),
 }))

-// Mock portal components to avoid async/jsdom issues (consistent with sibling tests)
+// Mock portal components to avoid async test DOM issues (consistent with sibling tests)
 vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
   PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean, onOpenChange: (open: boolean) => void }) => (
diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx
index 237c72adf0..2af14c5864 100644
--- a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx
+++ b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx
@@ -142,7 +142,7 @@ describe('EndpointCard', () => {
     failureFlags.disable = false
     failureFlags.delete = false
     failureFlags.update = false
-    // Polyfill document.execCommand for copy-to-clipboard in jsdom
+    // Polyfill document.execCommand for copy-to-clipboard in the test DOM runtime
     if (typeof document.execCommand !== 'function') {
       document.execCommand = vi.fn().mockReturnValue(true)
     }
diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx
index ce53bf5b9a..5c4407b3c5 100644
--- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx
+++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx
@@ -102,10 +102,12 @@ vi.mock('@/app/components/base/ui/toast', () => ({
 }))

 const mockClipboardWriteText = vi.fn()
-Object.assign(navigator, {
-  clipboard: {
+Object.defineProperty(navigator, 'clipboard', {
+  value: {
     writeText: mockClipboardWriteText,
   },
+  configurable: true,
+  writable: true,
 })

 vi.mock('@/app/components/base/modal/modal', () => ({
@@ -192,6 +194,13 @@ describe('OAuthClientSettingsModal', () => {
     vi.clearAllMocks()
     mockUsePluginStore.mockReturnValue(mockPluginDetail)
     mockClipboardWriteText.mockResolvedValue(undefined)
+    Object.defineProperty(navigator, 'clipboard', {
+      value: {
+        writeText: mockClipboardWriteText,
+      },
+      configurable: true,
+      writable: true,
+    })
     setMockFormValues({
       values: { client_id: 'test-client-id', client_secret: 'test-client-secret' },
       isCheckValidated: true,
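`navigator.clipboard` is typically exposed as a getter-only accessor, so the earlier `Object.assign(navigator, ...)` could silently fail to install the stub; `Object.defineProperty` replaces the property outright, and `configurable: true` lets the `beforeEach` above re-install it. A minimal sketch, assuming Vitest:

```ts
import { vi } from 'vitest'

const writeText = vi.fn().mockResolvedValue(undefined)

// defineProperty overrides the read-only accessor; configurable: true keeps
// the property re-definable for later suites or per-test re-installation.
Object.defineProperty(navigator, 'clipboard', {
  value: { writeText },
  configurable: true,
  writable: true,
})
```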
diff --git a/web/app/components/workflow/__tests__/custom-edge-linear-gradient-render.spec.tsx b/web/app/components/workflow/__tests__/custom-edge-linear-gradient-render.spec.tsx
index e962923158..973dfacbc8 100644
--- a/web/app/components/workflow/__tests__/custom-edge-linear-gradient-render.spec.tsx
+++ b/web/app/components/workflow/__tests__/custom-edge-linear-gradient-render.spec.tsx
@@ -48,10 +48,10 @@ describe('CustomEdgeLinearGradientRender', () => {
     const stops = container.querySelectorAll('stop')
     expect(stops).toHaveLength(2)
     expect(stops[0]).toHaveAttribute('offset', '0%')
-    expect(stops[0].getAttribute('style')).toContain('stop-color: rgb(17, 17, 17)')
+    expect(stops[0].getAttribute('style')).toContain('stop-color: #111111')
     expect(stops[0].getAttribute('style')).toContain('stop-opacity: 1')
     expect(stops[1]).toHaveAttribute('offset', '100%')
-    expect(stops[1].getAttribute('style')).toContain('stop-color: rgb(34, 34, 34)')
+    expect(stops[1].getAttribute('style')).toContain('stop-color: #222222')
     expect(stops[1].getAttribute('style')).toContain('stop-opacity: 1')
   })
 })
diff --git a/web/app/components/workflow/__tests__/selection-contextmenu.spec.tsx b/web/app/components/workflow/__tests__/selection-contextmenu.spec.tsx
index b153eb8b8a..1106cfcb75 100644
--- a/web/app/components/workflow/__tests__/selection-contextmenu.spec.tsx
+++ b/web/app/components/workflow/__tests__/selection-contextmenu.spec.tsx
@@ -10,6 +10,9 @@ import { renderWorkflowFlowComponent } from './workflow-test-env'
 let latestNodes: Node[] = []
 let latestHistoryEvent: string | undefined
 const mockGetNodesReadOnly = vi.fn()
+const mockHandleNodesCopy = vi.fn()
+const mockHandleNodesDuplicate = vi.fn()
+const mockHandleNodesDelete = vi.fn()

 vi.mock('../hooks', async () => {
   const actual = await vi.importActual('../hooks')
@@ -18,6 +21,11 @@ vi.mock('../hooks', async () => {
     useNodesReadOnly: () => ({
       getNodesReadOnly: mockGetNodesReadOnly,
     }),
+    useNodesInteractions: () => ({
+      handleNodesCopy: mockHandleNodesCopy,
+      handleNodesDuplicate: mockHandleNodesDuplicate,
+      handleNodesDelete: mockHandleNodesDelete,
+    }),
   }
 })
@@ -73,6 +81,9 @@ describe('SelectionContextmenu', () => {
     latestHistoryEvent = undefined
     mockGetNodesReadOnly.mockReset()
     mockGetNodesReadOnly.mockReturnValue(false)
+    mockHandleNodesCopy.mockReset()
+    mockHandleNodesDuplicate.mockReset()
+    mockHandleNodesDelete.mockReset()
   })

   it('should not render when selectionMenu is absent', () => {
@@ -97,6 +108,40 @@ describe('SelectionContextmenu', () => {
     })
   })

+  it('should render and execute copy/duplicate/delete operations', async () => {
+    const nodes = [
+      createNode({ id: 'n1', selected: true, width: 80, height: 40 }),
+      createNode({ id: 'n2', selected: true, position: { x: 140, y: 0 }, width: 80, height: 40 }),
+    ]
+    const { store } = renderSelectionMenu({ nodes })
+
+    act(() => {
+      store.setState({ selectionMenu: { clientX: 120, clientY: 120 } })
+    })
+
+    await waitFor(() => {
+      expect(screen.getByTestId('selection-contextmenu-item-copy')).toBeInTheDocument()
+    })
+
+    fireEvent.click(screen.getByTestId('selection-contextmenu-item-copy'))
+    expect(mockHandleNodesCopy).toHaveBeenCalledTimes(1)
+    expect(store.getState().selectionMenu).toBeUndefined()
+
+    act(() => {
+      store.setState({ selectionMenu: { clientX: 120, clientY: 120 } })
+    })
+    fireEvent.click(screen.getByTestId('selection-contextmenu-item-duplicate'))
+    expect(mockHandleNodesDuplicate).toHaveBeenCalledTimes(1)
+    expect(store.getState().selectionMenu).toBeUndefined()
+
+    act(() => {
+      store.setState({ selectionMenu: { clientX: 120, clientY: 120 } })
+    })
+    fireEvent.click(screen.getByTestId('selection-contextmenu-item-delete'))
+    expect(mockHandleNodesDelete).toHaveBeenCalledTimes(1)
+    expect(store.getState().selectionMenu).toBeUndefined()
+  })
+
   it('should close itself when only one node is selected', async () => {
     const nodes = [
       createNode({ id: 'n1', selected: true, width: 80, height: 40 }),
diff --git a/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx b/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx
index 82645f2028..961ab6ddb4 100644
--- a/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx
+++ b/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx
@@ -209,7 +209,7 @@ describe('UpdateDSLModal', () => {
     })

     await waitFor(() => {
-      expect(screen.getByRole('button', { name: 'app.newApp.Cancel' })).toBeInTheDocument()
+      expect(screen.getByRole('button', { name: 'app.newApp.Confirm' })).toBeInTheDocument()
     }, { timeout: 1000 })

     fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Cancel' }))
diff --git a/web/app/components/workflow/nodes/trigger-schedule/components/__tests__/integration.spec.tsx b/web/app/components/workflow/nodes/trigger-schedule/components/__tests__/integration.spec.tsx
index 00a6cbbe29..f2441e78d0 100644
--- a/web/app/components/workflow/nodes/trigger-schedule/components/__tests__/integration.spec.tsx
+++ b/web/app/components/workflow/nodes/trigger-schedule/components/__tests__/integration.spec.tsx
@@ -1,6 +1,6 @@
 /* eslint-disable ts/no-explicit-any */
 import type { ScheduleTriggerNodeType } from '../../types'
-import { fireEvent, render, screen, waitFor, within } from '@testing-library/react'
+import { render, screen, waitFor, within } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import FrequencySelector from '../frequency-selector'
 import ModeSwitcher from '../mode-switcher'
@@ -44,14 +44,14 @@ describe('trigger-schedule components', () => {
     )

     const trigger = screen.getByRole('button', { name: 'workflow.nodes.triggerSchedule.frequency.daily' })
-    fireEvent.click(trigger)
+    await user.click(trigger)

     await waitFor(() => {
       expect(trigger).toHaveAttribute('aria-expanded', 'true')
     })

     const listbox = await screen.findByRole('listbox')
-    await user.click(within(listbox).getByText('workflow.nodes.triggerSchedule.frequency.weekly'))
+    await user.click(within(listbox).getByRole('option', { name: 'workflow.nodes.triggerSchedule.frequency.weekly' }))

     await waitFor(() => {
       expect(onChange).toHaveBeenCalledWith('weekly')
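The integration hunk above trades `fireEvent.click` for `userEvent.click`, which dispatches the full pointer and mouse event sequence, and trades a bare text query for a role-scoped `getByRole('option', ...)`. A toy example of the same pattern, assuming Vitest and Testing Library; the listbox markup is a hypothetical stand-in:

```tsx
import { render, screen, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { expect, it } from 'vitest'

it('clicks an option through full pointer semantics', async () => {
  const user = userEvent.setup()
  render(
    <ul role="listbox">
      <li role="option" aria-selected={false}>weekly</li>
    </ul>,
  )
  // The role query doubles as an accessibility assertion: it fails loudly
  // if the element stops being exposed as an option.
  const option = within(screen.getByRole('listbox')).getByRole('option', { name: 'weekly' })
  await user.click(option)
  expect(option).toBeInTheDocument()
})
```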
diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/__tests__/variable-modal.spec.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/__tests__/variable-modal.spec.tsx
index 319e3803f4..297b534a6a 100644
--- a/web/app/components/workflow/panel/chat-variable-panel/components/__tests__/variable-modal.spec.tsx
+++ b/web/app/components/workflow/panel/chat-variable-panel/components/__tests__/variable-modal.spec.tsx
@@ -150,7 +150,7 @@ describe('variable-modal', () => {
     await user.click(screen.getByText('workflow.chatVariable.modal.editInJSON'))

     await waitFor(() => {
-      expect(screen.getByText('Loading...')).toBeInTheDocument()
+      expect(screen.getByTestId('monaco-editor')).toBeInTheDocument()
     })
     await user.click(screen.getByText('workflow.chatVariable.modal.editInForm'))
     expect(screen.getByDisplayValue('enabled')).toBeInTheDocument()
diff --git a/web/app/components/workflow/selection-contextmenu.tsx b/web/app/components/workflow/selection-contextmenu.tsx
index c13d881cc2..2b22df5012 100644
--- a/web/app/components/workflow/selection-contextmenu.tsx
+++ b/web/app/components/workflow/selection-contextmenu.tsx
@@ -16,9 +16,10 @@ import {
   ContextMenuItem,
   ContextMenuSeparator,
 } from '@/app/components/base/ui/context-menu'
-import { useNodesReadOnly, useNodesSyncDraft } from './hooks'
+import { useNodesInteractions, useNodesReadOnly, useNodesSyncDraft } from './hooks'
 import { useSelectionInteractions } from './hooks/use-selection-interactions'
 import { useWorkflowHistory, WorkflowHistoryEvent } from './hooks/use-workflow-history'
+import ShortcutsName from './shortcuts-name'
 import { useStore, useWorkflowStore } from './store'

 const AlignType = {
@@ -223,6 +224,7 @@ const SelectionContextmenu = () => {
   const { t } = useTranslation()
   const { getNodesReadOnly } = useNodesReadOnly()
   const { handleSelectionContextmenuCancel } = useSelectionInteractions()
+  const { handleNodesCopy, handleNodesDelete, handleNodesDuplicate } = useNodesInteractions()
   const selectionMenu = useStore(s => s.selectionMenu)
   const store = useStoreApi()
   const workflowStore = useWorkflowStore()
@@ -251,6 +253,21 @@ const SelectionContextmenu = () => {
     handleSelectionContextmenuCancel()
   }, [selectionMenu, selectedNodes.length, handleSelectionContextmenuCancel])

+  const handleCopyNodes = useCallback(() => {
+    handleNodesCopy()
+    handleSelectionContextmenuCancel()
+  }, [handleNodesCopy, handleSelectionContextmenuCancel])
+
+  const handleDuplicateNodes = useCallback(() => {
+    handleNodesDuplicate()
+    handleSelectionContextmenuCancel()
+  }, [handleNodesDuplicate, handleSelectionContextmenuCancel])
+
+  const handleDeleteNodes = useCallback(() => {
+    handleNodesDelete()
+    handleSelectionContextmenuCancel()
+  }, [handleNodesDelete, handleSelectionContextmenuCancel])
+
   const handleAlignNodes = useCallback((alignType: AlignTypeValue) => {
     if (getNodesReadOnly() || selectedNodes.length <= 1) {
       handleSelectionContextmenuCancel()
@@ -329,6 +346,36 @@ const SelectionContextmenu = () => {
       popupClassName="w-[240px]"
       positionerProps={anchor ? { anchor } : undefined}
     >
+
+
+        {t('common.copy', { defaultValue: 'common.copy', ns: 'workflow' })}
+
+
+
+        {t('common.duplicate', { defaultValue: 'common.duplicate', ns: 'workflow' })}
+
+
+
+
+
+
+        {t('operation.delete', { defaultValue: 'operation.delete', ns: 'common' })}
+
+
+
       {menuSections.map((section, sectionIndex) => (
         {sectionIndex > 0 && }
diff --git a/web/docs/test.md b/web/docs/test.md
index cb22b73b15..bc1546a991 100644
--- a/web/docs/test.md
+++ b/web/docs/test.md
@@ -8,7 +8,7 @@ When I ask you to write/refactor/fix tests, follow these rules by default.

 - **Framework**: Next.js 15 + React 19 + TypeScript
 - **Testing Tools**: Vitest 4.0.16 + React Testing Library 16.0
-- **Test Environment**: jsdom
+- **Test Environment**: happy-dom
 - **File Naming**: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
 - **Placement Rule**: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
@@ -30,7 +30,7 @@ pnpm test path/to/file.spec.tsx

 ## Project Test Setup

-- **Configuration**: `vitest.config.ts` sets the `jsdom` environment, loads the Testing Library presets, and respects our path aliases (`@/...`). Check this file before adding new transformers or module name mappers.
+- **Configuration**: `vite.config.ts` sets the `happy-dom` environment, loads the Testing Library presets, and respects our path aliases (`@/...`). Check this file before adding new transformers or module name mappers.
 - **Global setup**: `vitest.setup.ts` already imports `@testing-library/jest-dom`, runs `cleanup()` after every test, and defines shared mocks (for example `react-i18next`). Add any environment-level mocks (for example `ResizeObserver`, `matchMedia`, `IntersectionObserver`, `TextEncoder`, `crypto`) here so they are shared consistently.
 - **Reusable mocks**: Place shared mock factories inside `web/__mocks__/` and use `vi.mock('module-name')` to point to them rather than redefining mocks in every spec.
 - **Mocking behavior**: Modules are not mocked automatically. Use `vi.mock(...)` in tests, or place global mocks in `vitest.setup.ts`.
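For reference, a minimal spec following the conventions the guide above documents (sibling `__tests__/` folder, Vitest + React Testing Library, `*.spec.tsx` naming); the `Hello` component and its placement are hypothetical:

```tsx
// web/app/components/hello/__tests__/index.spec.tsx (hypothetical placement)
import { render, screen } from '@testing-library/react'
import { describe, expect, it } from 'vitest'

const Hello = ({ name }: { name: string }) => <p>Hello, {name}</p>

describe('Hello', () => {
  it('greets by name', () => {
    render(<Hello name="Dify" />)
    expect(screen.getByText('Hello, Dify')).toBeInTheDocument()
  })
})
```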
diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json
index 98d2976f8e..e34300993a 100644
--- a/web/eslint-suppressions.json
+++ b/web/eslint-suppressions.json
@@ -9,14 +9,6 @@
       "count": 1
     }
   },
-  "__tests__/check-i18n.test.ts": {
-    "regexp/no-unused-capturing-group": {
-      "count": 1
-    },
-    "ts/no-explicit-any": {
-      "count": 2
-    }
-  },
   "__tests__/document-detail-navigation-fix.test.tsx": {
     "no-console": {
       "count": 10
diff --git a/web/i18n/ar-TN/app-debug.json b/web/i18n/ar-TN/app-debug.json
index c35983397a..99c510dcb5 100644
--- a/web/i18n/ar-TN/app-debug.json
+++ b/web/i18n/ar-TN/app-debug.json
@@ -235,11 +235,16 @@
   "inputs.run": "تشغيل",
   "inputs.title": "تصحيح ومعاينة",
   "inputs.userInputField": "حقل إدخال المستخدم",
+  "manageModels": "إدارة النماذج",
   "modelConfig.modeType.chat": "دردشة",
   "modelConfig.modeType.completion": "إكمال",
   "modelConfig.model": "نموذج",
   "modelConfig.setTone": "تعيين نبرة الاستجابات",
   "modelConfig.title": "النموذج والمعلمات",
+  "noModelProviderConfigured": "لم يتم تكوين أي مزوّد نماذج",
+  "noModelProviderConfiguredTip": "قم بتثبيت أو تكوين مزوّد نماذج للبدء.",
+  "noModelSelected": "لم يتم اختيار أي نموذج",
+  "noModelSelectedTip": "قم بتكوين نموذج أعلاه للمتابعة.",
   "noResult": "سيتم عرض الإخراج هنا.",
   "notSetAPIKey.description": "لم يتم تعيين مفتاح مزود LLM، ويجب تعيينه قبل تصحيح الأخطاء.",
   "notSetAPIKey.settingBtn": "الذهاب إلى الإعدادات",
diff --git a/web/i18n/ar-TN/common.json b/web/i18n/ar-TN/common.json
index 74fe80bcff..3bc7c05564 100644
--- a/web/i18n/ar-TN/common.json
+++ b/web/i18n/ar-TN/common.json
@@ -340,11 +340,25 @@
   "modelProvider.auth.unAuthorized": "غير مصرح به",
   "modelProvider.buyQuota": "شراء حصة",
   "modelProvider.callTimes": "أوقات الاتصال",
+  "modelProvider.card.aiCreditsInUse": "أرصدة الذكاء الاصطناعي قيد الاستخدام",
+  "modelProvider.card.aiCreditsOption": "أرصدة الذكاء الاصطناعي",
+  "modelProvider.card.apiKeyOption": "مفتاح API",
+  "modelProvider.card.apiKeyRequired": "مفتاح API مطلوب",
+  "modelProvider.card.apiKeyUnavailableFallback": "مفتاح API غير متاح، يتم الآن استخدام أرصدة الذكاء الاصطناعي",
+  "modelProvider.card.apiKeyUnavailableFallbackDescription": "تحقق من تكوين مفتاح API الخاص بك للتبديل مرة أخرى",
   "modelProvider.card.buyQuota": "شراء حصة",
   "modelProvider.card.callTimes": "أوقات الاتصال",
+  "modelProvider.card.creditsExhaustedDescription": "يرجى ترقية خطتك أو تكوين مفتاح API",
+  "modelProvider.card.creditsExhaustedFallback": "نفدت أرصدة الذكاء الاصطناعي، يتم الآن استخدام مفتاح API",
+  "modelProvider.card.creditsExhaustedFallbackDescription": "قم بترقية خطتك لاستئناف أولوية أرصدة الذكاء الاصطناعي.",
+  "modelProvider.card.creditsExhaustedMessage": "نفدت أرصدة الذكاء الاصطناعي",
   "modelProvider.card.modelAPI": "نماذج {{modelName}} تستخدم مفتاح API.",
   "modelProvider.card.modelNotSupported": "نماذج {{modelName}} غير مثبتة.",
   "modelProvider.card.modelSupported": "نماذج {{modelName}} تستخدم هذا الحصة.",
+  "modelProvider.card.noApiKeysDescription": "أضف مفتاح API لبدء استخدام بيانات اعتماد النموذج الخاصة بك.",
+  "modelProvider.card.noApiKeysFallback": "لا توجد مفاتيح API، يتم استخدام أرصدة الذكاء الاصطناعي بدلاً من ذلك",
+  "modelProvider.card.noApiKeysTitle": "لم يتم تكوين أي مفاتيح API بعد",
+  "modelProvider.card.noAvailableUsage": "لا يوجد استخدام متاح",
   "modelProvider.card.onTrial": "في التجربة",
   "modelProvider.card.paid": "مدفوع",
   "modelProvider.card.priorityUse": "أولوية الاستخدام",
@@ -353,6 +367,11 @@
   "modelProvider.card.removeKey": "إزالة مفتاح API",
   "modelProvider.card.tip": "تدعم أرصدة الرسائل نماذج من {{modelNames}}. ستعطى الأولوية للحصة المدفوعة. سيتم استخدام الحصة المجانية بعد نفاد الحصة المدفوعة.",
   "modelProvider.card.tokens": "رموز",
+  "modelProvider.card.unavailable": "غير متاح",
+  "modelProvider.card.upgradePlan": "ترقية خطتك",
+  "modelProvider.card.usageLabel": "الاستخدام",
+  "modelProvider.card.usagePriority": "أولوية الاستخدام",
+  "modelProvider.card.usagePriorityTip": "تعيين المورد الذي يجب استخدامه أولاً عند تشغيل النماذج.",
   "modelProvider.collapse": "طي",
   "modelProvider.config": "تكوين",
   "modelProvider.configLoadBalancing": "تكوين موازنة التحميل",
@@ -387,9 +406,11 @@
   "modelProvider.model": "النموذج",
   "modelProvider.modelAndParameters": "النموذج والمعلمات",
   "modelProvider.modelHasBeenDeprecated": "تم إهمال هذا النموذج",
+  "modelProvider.modelSettings": "إعدادات النموذج",
   "modelProvider.models": "النماذج",
   "modelProvider.modelsNum": "{{num}} نماذج",
   "modelProvider.noModelFound": "لم يتم العثور على نموذج لـ {{model}}",
+  "modelProvider.noneConfigured": "قم بتكوين نموذج نظام افتراضي لتشغيل التطبيقات",
   "modelProvider.notConfigured": "لم يتم تكوين نموذج النظام بالكامل بعد",
   "modelProvider.parameters": "المعلمات",
   "modelProvider.parametersInvalidRemoved": "بعض المعلمات غير صالحة وتمت إزالتها",
@@ -403,8 +424,25 @@
   "modelProvider.resetDate": "إعادة التعيين في {{date}}",
   "modelProvider.searchModel": "نموذج البحث",
   "modelProvider.selectModel": "اختر نموذجك",
+  "modelProvider.selector.aiCredits": "أرصدة الذكاء الاصطناعي",
+  "modelProvider.selector.apiKeyUnavailable": "مفتاح API غير متاح",
+  "modelProvider.selector.apiKeyUnavailableTip": "تمت إزالة مفتاح API. يرجى تكوين مفتاح API جديد.",
+  "modelProvider.selector.configure": "تكوين",
+  "modelProvider.selector.configureRequired": "التكوين مطلوب",
+  "modelProvider.selector.creditsExhausted": "نفدت الأرصدة",
+  "modelProvider.selector.creditsExhaustedTip": "نفدت أرصدة الذكاء الاصطناعي الخاصة بك. يرجى ترقية خطتك أو إضافة مفتاح API.",
+  "modelProvider.selector.disabled": "معطل",
+  "modelProvider.selector.discoverMoreInMarketplace": "اكتشف المزيد في السوق",
   "modelProvider.selector.emptySetting": "يرجى الانتقال إلى الإعدادات للتكوين",
   "modelProvider.selector.emptyTip": "لا توجد نماذج متاحة",
+  "modelProvider.selector.fromMarketplace": "من السوق",
+  "modelProvider.selector.incompatible": "غير متوافق",
+  "modelProvider.selector.incompatibleTip": "هذا النموذج غير متاح في الإصدار الحالي. يرجى تحديد نموذج متاح آخر.",
+  "modelProvider.selector.install": "تثبيت",
+  "modelProvider.selector.modelProviderSettings": "إعدادات مزود النموذج",
+  "modelProvider.selector.noProviderConfigured": "لم يتم تكوين أي مزود نموذج",
+  "modelProvider.selector.noProviderConfiguredDesc": "تصفح السوق لتثبيت مزود، أو قم بتكوين المزودين في الإعدادات.",
+  "modelProvider.selector.onlyCompatibleModelsShown": "يتم عرض النماذج المتوافقة فقط",
   "modelProvider.selector.rerankTip": "يرجى إعداد نموذج إعادة الترتيب",
يرجى إضافة نموذج أو تحديد نموذج آخر.", "modelProvider.setupModelFirst": "يرجى إعداد نموذجك أولاً", diff --git a/web/i18n/ar-TN/plugin.json b/web/i18n/ar-TN/plugin.json index 39052f4295..c48a15fb24 100644 --- a/web/i18n/ar-TN/plugin.json +++ b/web/i18n/ar-TN/plugin.json @@ -3,6 +3,7 @@ "action.delete": "إزالة الإضافة", "action.deleteContentLeft": "هل ترغب في إزالة ", "action.deleteContentRight": " الإضافة؟", + "action.deleteSuccess": "تم حذف الإضافة بنجاح", "action.pluginInfo": "معلومات الإضافة", "action.usedInApps": "يتم استخدام هذه الإضافة في {{num}} تطبيقات.", "allCategories": "جميع الفئات", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "تثبيت", "detailPanel.operation.remove": "إزالة", "detailPanel.operation.update": "تحديث", + "detailPanel.operation.updateTooltip": "قم بالتحديث للوصول إلى أحدث النماذج.", "detailPanel.operation.viewDetail": "عرض التفاصيل", "detailPanel.serviceOk": "الخدمة جيدة", "detailPanel.strategyNum": "{{num}} {{strategy}} متضمن", @@ -231,12 +233,18 @@ "source.local": "ملف الحزمة المحلية", "source.marketplace": "السوق", "task.clearAll": "مسح الكل", + "task.errorMsg.github": "لم يتم تثبيت هذه الإضافة تلقائيًا.\nيرجى تثبيتها من GitHub.", + "task.errorMsg.marketplace": "لم يتم تثبيت هذه الإضافة تلقائيًا.\nيرجى تثبيتها من السوق.", + "task.errorMsg.unknown": "لم يتم تثبيت هذه الإضافة.\nتعذر التعرف على مصدر الإضافة.", "task.errorPlugins": "فشل في تثبيت الإضافات", "task.installError": "{{errorLength}} إضافات فشل تثبيتها، انقر للعرض", + "task.installFromGithub": "التثبيت من GitHub", + "task.installFromMarketplace": "التثبيت من السوق", "task.installSuccess": "تم تثبيت {{successLength}} من الإضافات بنجاح", "task.installed": "مثبت", "task.installedError": "{{errorLength}} إضافات فشل تثبيتها", "task.installing": "جارٍ تثبيت الإضافات.", + "task.installingHint": "جارٍ التثبيت... قد يستغرق هذا بضع دقائق.", "task.installingWithError": "تثبيت {{installingLength}} إضافات، {{successLength}} نجاح، {{errorLength}} فشل", "task.installingWithSuccess": "تثبيت {{installingLength}} إضافات، {{successLength}} نجاح.", "task.runningPlugins": "تثبيت الإضافات", diff --git a/web/i18n/ar-TN/workflow.json b/web/i18n/ar-TN/workflow.json index 1bb5237ff7..2487538071 100644 --- a/web/i18n/ar-TN/workflow.json +++ b/web/i18n/ar-TN/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "تحديث سير العمل", "error.startNodeRequired": "الرجاء إضافة عقدة البداية أولاً قبل {{operation}}", "errorMsg.authRequired": "الترخيص مطلوب", + "errorMsg.configureModel": "قم بتكوين نموذج", "errorMsg.fieldRequired": "{{field}} مطلوب", "errorMsg.fields.code": "الكود", "errorMsg.fields.model": "النموذج", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "متغير الرؤية", "errorMsg.invalidJson": "{{field}} هو JSON غير صالح", "errorMsg.invalidVariable": "متغير غير صالح", + "errorMsg.modelPluginNotInstalled": "متغير غير صالح. 
قم بتكوين نموذج لتمكين هذا المتغير.", "errorMsg.noValidTool": "{{field}} لا توجد أداة صالحة محددة", "errorMsg.rerankModelRequired": "مطلوب تكوين نموذج Rerank", "errorMsg.startNodeRequired": "الرجاء إضافة عقدة البداية أولاً قبل {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "حجم النافذة", "nodes.common.outputVars": "متغيرات الإخراج", "nodes.common.pluginNotInstalled": "الإضافة غير مثبتة", + "nodes.common.pluginsNotInstalled": "{{count}} إضافات غير مثبتة", "nodes.common.retry.maxRetries": "الحد الأقصى لإعادة المحاولة", "nodes.common.retry.ms": "مللي ثانية", "nodes.common.retry.retries": "{{num}} إعادة محاولة", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "القطع", "nodes.knowledgeBase.chunksInputTip": "متغير الإدخال لعقدة قاعدة المعرفة هو Pieces. نوع المتغير هو كائن بمخطط JSON محدد يجب أن يكون متسقًا مع هيكل القطعة المحدد.", "nodes.knowledgeBase.chunksVariableIsRequired": "متغير القطع مطلوب", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "مفتاح API غير متاح", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "نفدت الأرصدة", + "nodes.knowledgeBase.embeddingModelIncompatible": "غير متوافق", "nodes.knowledgeBase.embeddingModelIsInvalid": "نموذج التضمين غير صالح", "nodes.knowledgeBase.embeddingModelIsRequired": "نموذج التضمين مطلوب", + "nodes.knowledgeBase.embeddingModelNotConfigured": "نموذج التضمين غير مكوّن", "nodes.knowledgeBase.indexMethodIsRequired": "طريقة الفهرسة مطلوبة", + "nodes.knowledgeBase.notConfigured": "غير مكوّن", "nodes.knowledgeBase.rerankingModelIsInvalid": "نموذج إعادة الترتيب غير صالح", "nodes.knowledgeBase.rerankingModelIsRequired": "نموذج إعادة الترتيب مطلوب", "nodes.knowledgeBase.retrievalSettingIsRequired": "إعداد الاسترجاع مطلوب", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "يدعم Jinja2 فقط", "nodes.templateTransform.inputVars": "متغيرات الإدخال", "nodes.templateTransform.outputVars.output": "المحتوى المحول", + "nodes.tool.authorizationRequired": "التفويض مطلوب", "nodes.tool.authorize": "تخويل", "nodes.tool.inputVars": "متغيرات الإدخال", "nodes.tool.insertPlaceholder1": "اكتب أو اضغط", @@ -1062,10 +1071,12 @@ "panel.change": "تغيير", "panel.changeBlock": "تغيير العقدة", "panel.checklist": "قائمة المراجعة", + "panel.checklistDescription": "حل المشكلات التالية قبل النشر", "panel.checklistResolved": "تم حل جميع المشكلات", "panel.checklistTip": "تأكد من حل جميع المشكلات قبل النشر", "panel.createdBy": "تم الإنشاء بواسطة ", "panel.goTo": "الذهاب إلى", + "panel.goToFix": "الذهاب إلى الإصلاح", "panel.helpLink": "عرض المستندات", "panel.maximize": "تكبير القماش", "panel.minimize": "خروج من وضع ملء الشاشة", diff --git a/web/i18n/de-DE/app-debug.json b/web/i18n/de-DE/app-debug.json index f75982bee9..0e067e20eb 100644 --- a/web/i18n/de-DE/app-debug.json +++ b/web/i18n/de-DE/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "AUSFÜHREN", "inputs.title": "Debug und Vorschau", "inputs.userInputField": "Benutzereingabefeld", + "manageModels": "Modelle verwalten", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Vollständig", "modelConfig.model": "Modell", "modelConfig.setTone": "Ton der Antworten festlegen", "modelConfig.title": "Modell und Parameter", + "noModelProviderConfigured": "Kein Modellanbieter konfiguriert", + "noModelProviderConfiguredTip": "Installieren oder konfigurieren Sie einen Modellanbieter, um zu beginnen.", + "noModelSelected": "Kein Modell ausgewählt", + "noModelSelectedTip": "Konfigurieren Sie oben ein Modell, um fortzufahren.", "noResult": "Hier wird die Ausgabe 
angezeigt.", "notSetAPIKey.description": "Der LLM-Anbieterschlüssel wurde nicht festgelegt und muss vor dem Debuggen festgelegt werden.", "notSetAPIKey.settingBtn": "Zu den Einstellungen gehen", diff --git a/web/i18n/de-DE/common.json b/web/i18n/de-DE/common.json index a7aaf55086..8639a24f3e 100644 --- a/web/i18n/de-DE/common.json +++ b/web/i18n/de-DE/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Unbefugt", "modelProvider.buyQuota": "Kontingent kaufen", "modelProvider.callTimes": "Anrufzeiten", + "modelProvider.card.aiCreditsInUse": "AI Credits werden verwendet", + "modelProvider.card.aiCreditsOption": "AI Credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API Key erforderlich", + "modelProvider.card.apiKeyUnavailableFallback": "API Key nicht verfügbar, AI Credits werden verwendet", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Überprüfen Sie Ihre API-Key-Konfiguration, um zurückzuwechseln", "modelProvider.card.buyQuota": "Kontingent kaufen", "modelProvider.card.callTimes": "Anrufzeiten", + "modelProvider.card.creditsExhaustedDescription": "Bitte upgraden Sie Ihren Plan oder konfigurieren Sie einen API Key", + "modelProvider.card.creditsExhaustedFallback": "AI Credits aufgebraucht, API Key wird verwendet", + "modelProvider.card.creditsExhaustedFallbackDescription": "Upgraden Sie Ihren Plan, um die AI-Credit-Priorität wiederherzustellen.", + "modelProvider.card.creditsExhaustedMessage": "AI Credits wurden aufgebraucht", "modelProvider.card.modelAPI": "{{modelName}}-Modelle verwenden den API-Schlüssel.", "modelProvider.card.modelNotSupported": "{{modelName}}-Modelle sind nicht installiert.", "modelProvider.card.modelSupported": "{{modelName}}-Modelle verwenden dieses Kontingent.", + "modelProvider.card.noApiKeysDescription": "Fügen Sie einen API Key hinzu, um Ihre eigenen Modell-Zugangsdaten zu verwenden.", + "modelProvider.card.noApiKeysFallback": "Keine API Keys, AI Credits werden verwendet", + "modelProvider.card.noApiKeysTitle": "Noch keine API Keys konfiguriert", + "modelProvider.card.noAvailableUsage": "Kein verfügbares Guthaben", "modelProvider.card.onTrial": "In Probe", "modelProvider.card.paid": "Bezahlt", "modelProvider.card.priorityUse": "Priorisierte Nutzung", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "API-Schlüssel entfernen", "modelProvider.card.tip": "Nachrichtenguthaben unterstützen Modelle von {{modelNames}}. Der bezahlten Kontingent wird Vorrang gegeben. 
Das kostenlose Kontingent wird nach dem Verbrauch des bezahlten Kontingents verwendet.", "modelProvider.card.tokens": "Token", + "modelProvider.card.unavailable": "Nicht verfügbar", + "modelProvider.card.upgradePlan": "Plan upgraden", + "modelProvider.card.usageLabel": "Verbrauch", + "modelProvider.card.usagePriority": "Nutzungspriorität", + "modelProvider.card.usagePriorityTip": "Legen Sie fest, welche Ressource beim Ausführen von Modellen zuerst verwendet wird.", "modelProvider.collapse": "Einklappen", "modelProvider.config": "Konfigurieren", "modelProvider.configLoadBalancing": "Lastenausgleich für die Konfiguration", @@ -387,9 +406,11 @@ "modelProvider.model": "Modell", "modelProvider.modelAndParameters": "Modell und Parameter", "modelProvider.modelHasBeenDeprecated": "Dieses Modell ist veraltet", + "modelProvider.modelSettings": "Modelleinstellungen", "modelProvider.models": "Modelle", "modelProvider.modelsNum": "{{num}} Modelle", "modelProvider.noModelFound": "Kein Modell für {{model}} gefunden", + "modelProvider.noneConfigured": "Konfigurieren Sie ein Standard-Systemmodell, um Anwendungen auszuführen", "modelProvider.notConfigured": "Das Systemmodell wurde noch nicht vollständig konfiguriert, und einige Funktionen sind möglicherweise nicht verfügbar.", "modelProvider.parameters": "PARAMETER", "modelProvider.parametersInvalidRemoved": "Einige Parameter sind ungültig und wurden entfernt.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Zurücksetzen am {{date}}", "modelProvider.searchModel": "Suchmodell", "modelProvider.selectModel": "Wählen Sie Ihr Modell", + "modelProvider.selector.aiCredits": "AI Credits", + "modelProvider.selector.apiKeyUnavailable": "API Key nicht verfügbar", + "modelProvider.selector.apiKeyUnavailableTip": "Der API Key wurde entfernt. Bitte konfigurieren Sie einen neuen API Key.", + "modelProvider.selector.configure": "Konfigurieren", + "modelProvider.selector.configureRequired": "Konfiguration erforderlich", + "modelProvider.selector.creditsExhausted": "Guthaben aufgebraucht", + "modelProvider.selector.creditsExhaustedTip": "Ihre AI Credits wurden aufgebraucht. Bitte upgraden Sie Ihren Plan oder fügen Sie einen API Key hinzu.", + "modelProvider.selector.disabled": "Deaktiviert", + "modelProvider.selector.discoverMoreInMarketplace": "Mehr im Marketplace entdecken", "modelProvider.selector.emptySetting": "Bitte gehen Sie zu den Einstellungen, um zu konfigurieren", "modelProvider.selector.emptyTip": "Keine verfügbaren Modelle", + "modelProvider.selector.fromMarketplace": "Vom Marketplace", + "modelProvider.selector.incompatible": "Inkompatibel", + "modelProvider.selector.incompatibleTip": "Dieses Modell ist in der aktuellen Version nicht verfügbar. Bitte wählen Sie ein anderes verfügbares Modell.", + "modelProvider.selector.install": "Installieren", + "modelProvider.selector.modelProviderSettings": "Modellanbieter-Einstellungen", + "modelProvider.selector.noProviderConfigured": "Kein Modellanbieter konfiguriert", + "modelProvider.selector.noProviderConfiguredDesc": "Durchsuchen Sie den Marketplace, um einen zu installieren, oder konfigurieren Sie Anbieter in den Einstellungen.", + "modelProvider.selector.onlyCompatibleModelsShown": "Es werden nur kompatible Modelle angezeigt", "modelProvider.selector.rerankTip": "Bitte richten Sie das Rerank-Modell ein", "modelProvider.selector.tip": "Dieses Modell wurde entfernt. 
Bitte fügen Sie ein Modell hinzu oder wählen Sie ein anderes Modell.", "modelProvider.setupModelFirst": "Bitte richten Sie zuerst Ihr Modell ein", diff --git a/web/i18n/de-DE/plugin.json b/web/i18n/de-DE/plugin.json index c0158be29b..e805f3fabd 100644 --- a/web/i18n/de-DE/plugin.json +++ b/web/i18n/de-DE/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Plugin entfernen", "action.deleteContentLeft": "Möchten Sie", "action.deleteContentRight": "Plugin?", + "action.deleteSuccess": "Plugin erfolgreich entfernt", "action.pluginInfo": "Plugin-Info", "action.usedInApps": "Dieses Plugin wird in {{num}} Apps verwendet.", "allCategories": "Alle Kategorien", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Installieren", "detailPanel.operation.remove": "Entfernen", "detailPanel.operation.update": "Aktualisieren", + "detailPanel.operation.updateTooltip": "Aktualisieren Sie, um auf die neuesten Modelle zuzugreifen.", "detailPanel.operation.viewDetail": "Im Detail sehen", "detailPanel.serviceOk": "Service in Ordnung", "detailPanel.strategyNum": "{{num}} {{strategy}} IINKLUSIVE", @@ -231,12 +233,18 @@ "source.local": "Lokale Paketdatei", "source.marketplace": "Marktplatz", "task.clearAll": "Alle löschen", + "task.errorMsg.github": "Dieses Plugin konnte nicht automatisch installiert werden.\nBitte installieren Sie es von GitHub.", + "task.errorMsg.marketplace": "Dieses Plugin konnte nicht automatisch installiert werden.\nBitte installieren Sie es vom Marketplace.", + "task.errorMsg.unknown": "Dieses Plugin konnte nicht installiert werden.\nDie Plugin-Quelle konnte nicht identifiziert werden.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} Plugins konnten nicht installiert werden, klicken Sie hier, um sie anzusehen", + "task.installFromGithub": "Von GitHub installieren", + "task.installFromMarketplace": "Vom Marketplace installieren", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} Plugins konnten nicht installiert werden", "task.installing": "Plugins werden installiert.", + "task.installingHint": "Installation läuft... Dies kann einige Minuten dauern.", "task.installingWithError": "Installation von {{installingLength}} Plugins, {{successLength}} erfolgreich, {{errorLength}} fehlgeschlagen", "task.installingWithSuccess": "Installation von {{installingLength}} Plugins, {{successLength}} erfolgreich.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/de-DE/workflow.json b/web/i18n/de-DE/workflow.json index 48037ce006..362eac19b6 100644 --- a/web/i18n/de-DE/workflow.json +++ b/web/i18n/de-DE/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "Arbeitsablauf aktualisieren", "error.startNodeRequired": "Bitte füge zuerst einen Startknoten hinzu, bevor du {{operation}}.", "errorMsg.authRequired": "Autorisierung ist erforderlich", + "errorMsg.configureModel": "Modell konfigurieren", "errorMsg.fieldRequired": "{{field}} ist erforderlich", "errorMsg.fields.code": "Code", "errorMsg.fields.model": "Modell", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vision variabel", "errorMsg.invalidJson": "{{field}} ist ein ungültiges JSON", "errorMsg.invalidVariable": "Ungültige Variable", + "errorMsg.modelPluginNotInstalled": "Ungültige Variable. 
Konfigurieren Sie ein Modell, um diese Variable zu aktivieren.", "errorMsg.noValidTool": "{{field}} kein gültiges Werkzeug ausgewählt", "errorMsg.rerankModelRequired": "Bevor Sie das Rerank-Modell aktivieren, bestätigen Sie bitte, dass das Modell in den Einstellungen erfolgreich konfiguriert wurde.", "errorMsg.startNodeRequired": "Bitte füge zuerst einen Startknoten hinzu, bevor du {{operation}}.", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Fenstergröße", "nodes.common.outputVars": "Ausgabevariablen", "nodes.common.pluginNotInstalled": "Plugin ist nicht installiert", + "nodes.common.pluginsNotInstalled": "{{count}} Plugins nicht installiert", "nodes.common.retry.maxRetries": "Max. Wiederholungen", "nodes.common.retry.ms": "Frau", "nodes.common.retry.retries": "{{num}} Wiederholungen", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Stücke", "nodes.knowledgeBase.chunksInputTip": "Die Eingangsvariable des Wissensbasis-Knotens sind Chunks. Der Variablentyp ist ein Objekt mit einem spezifischen JSON-Schema, das konsistent mit der ausgewählten Chunk-Struktur sein muss.", "nodes.knowledgeBase.chunksVariableIsRequired": "Die Variable 'Chunks' ist erforderlich", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key nicht verfügbar", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Guthaben aufgebraucht", + "nodes.knowledgeBase.embeddingModelIncompatible": "Inkompatibel", "nodes.knowledgeBase.embeddingModelIsInvalid": "Einbettungsmodell ist ungültig", "nodes.knowledgeBase.embeddingModelIsRequired": "Ein Einbettungsmodell ist erforderlich", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Embedding-Modell nicht konfiguriert", "nodes.knowledgeBase.indexMethodIsRequired": "Index-Methode ist erforderlich", + "nodes.knowledgeBase.notConfigured": "Nicht konfiguriert", "nodes.knowledgeBase.rerankingModelIsInvalid": "Das Reranking-Modell ist ungültig", "nodes.knowledgeBase.rerankingModelIsRequired": "Ein Reranking-Modell ist erforderlich", "nodes.knowledgeBase.retrievalSettingIsRequired": "Abrufeinstellung ist erforderlich", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Unterstützt nur Jinja2", "nodes.templateTransform.inputVars": "Eingabevariablen", "nodes.templateTransform.outputVars.output": "Transformierter Inhalt", + "nodes.tool.authorizationRequired": "Autorisierung erforderlich", "nodes.tool.authorize": "Autorisieren", "nodes.tool.inputVars": "Eingabevariablen", "nodes.tool.insertPlaceholder1": "Tippen oder drücken", @@ -1062,10 +1071,12 @@ "panel.change": "Ändern", "panel.changeBlock": "Knoten ändern", "panel.checklist": "Checkliste", + "panel.checklistDescription": "Beheben Sie die folgenden Probleme vor der Veröffentlichung", "panel.checklistResolved": "Alle Probleme wurden gelöst", "panel.checklistTip": "Stellen Sie sicher, dass alle Probleme vor der Veröffentlichung gelöst sind", "panel.createdBy": "Erstellt von ", "panel.goTo": "Gehe zu", + "panel.goToFix": "Zur Behebung", "panel.helpLink": "Hilfe", "panel.maximize": "Maximiere die Leinwand", "panel.minimize": "Vollbildmodus beenden", diff --git a/web/i18n/es-ES/app-debug.json b/web/i18n/es-ES/app-debug.json index 8245f2b325..cfb8a0643e 100644 --- a/web/i18n/es-ES/app-debug.json +++ b/web/i18n/es-ES/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "EJECUTAR", "inputs.title": "Depurar y Previsualizar", "inputs.userInputField": "Campo de Entrada del Usuario", + "manageModels": "Gestionar modelos", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Completar", 
"modelConfig.model": "Modelo", "modelConfig.setTone": "Establecer tono de respuestas", "modelConfig.title": "Modelo y Parámetros", + "noModelProviderConfigured": "Ningún proveedor de modelos configurado", + "noModelProviderConfiguredTip": "Instala o configura un proveedor de modelos para comenzar.", + "noModelSelected": "Ningún modelo seleccionado", + "noModelSelectedTip": "Configura un modelo arriba para continuar.", "noResult": "La salida se mostrará aquí.", "notSetAPIKey.description": "La clave del proveedor LLM no se ha establecido, y debe configurarse antes de depurar.", "notSetAPIKey.settingBtn": "Ir a configuración", diff --git a/web/i18n/es-ES/common.json b/web/i18n/es-ES/common.json index 595124496b..1b97ce680d 100644 --- a/web/i18n/es-ES/common.json +++ b/web/i18n/es-ES/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "No autorizado", "modelProvider.buyQuota": "Comprar Cuota", "modelProvider.callTimes": "Tiempos de llamada", + "modelProvider.card.aiCreditsInUse": "AI credits en uso", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Se requiere API Key", + "modelProvider.card.apiKeyUnavailableFallback": "API Key no disponible, usando AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Comprueba la configuración de tu API Key para volver a usarla", "modelProvider.card.buyQuota": "Comprar Cuota", "modelProvider.card.callTimes": "Tiempos de llamada", + "modelProvider.card.creditsExhaustedDescription": "Por favor, mejora tu plan o configura una API Key", + "modelProvider.card.creditsExhaustedFallback": "AI credits agotados, usando API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Mejora tu plan para restablecer la prioridad de AI credits.", + "modelProvider.card.creditsExhaustedMessage": "Los AI credits se han agotado", "modelProvider.card.modelAPI": "Los modelos {{modelName}} están usando la clave API.", "modelProvider.card.modelNotSupported": "Los modelos {{modelName}} no están instalados.", "modelProvider.card.modelSupported": "Los modelos {{modelName}} están usando esta cuota.", + "modelProvider.card.noApiKeysDescription": "Añade una API Key para usar tus propias credenciales de modelo.", + "modelProvider.card.noApiKeysFallback": "Sin API Keys, usando AI credits", + "modelProvider.card.noApiKeysTitle": "No hay API Keys configuradas", + "modelProvider.card.noAvailableUsage": "Sin uso disponible", "modelProvider.card.onTrial": "En prueba", "modelProvider.card.paid": "Pagado", "modelProvider.card.priorityUse": "Uso prioritario", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Eliminar CLAVE API", "modelProvider.card.tip": "Créditos de mensajes admite modelos de {{modelNames}}. Se dará prioridad a la cuota pagada. 
La cuota gratuita se utilizará después de que se agote la cuota pagada.", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "No disponible", + "modelProvider.card.upgradePlan": "mejora tu plan", + "modelProvider.card.usageLabel": "Uso", + "modelProvider.card.usagePriority": "Prioridad de uso", + "modelProvider.card.usagePriorityTip": "Establece qué recurso usar primero al ejecutar modelos.", "modelProvider.collapse": "Colapsar", "modelProvider.config": "Configurar", "modelProvider.configLoadBalancing": "Configurar Balanceo de Carga", @@ -387,9 +406,11 @@ "modelProvider.model": "Modelo", "modelProvider.modelAndParameters": "Modelo y Parámetros", "modelProvider.modelHasBeenDeprecated": "Este modelo ha sido desaprobado", + "modelProvider.modelSettings": "Configuración de modelos", "modelProvider.models": "Modelos", "modelProvider.modelsNum": "{{num}} Modelos", "modelProvider.noModelFound": "No se encontró modelo para {{model}}", + "modelProvider.noneConfigured": "Configura un modelo de sistema predeterminado para ejecutar aplicaciones", "modelProvider.notConfigured": "El modelo del sistema aún no ha sido completamente configurado, y algunas funciones pueden no estar disponibles.", "modelProvider.parameters": "PARÁMETROS", "modelProvider.parametersInvalidRemoved": "Algunos parámetros son inválidos y han sido eliminados", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Restablecer el {{date}}", "modelProvider.searchModel": "Modelo de búsqueda", "modelProvider.selectModel": "Selecciona tu modelo", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key no disponible", + "modelProvider.selector.apiKeyUnavailableTip": "La API Key ha sido eliminada. Por favor, configura una nueva API Key.", + "modelProvider.selector.configure": "Configurar", + "modelProvider.selector.configureRequired": "Configuración requerida", + "modelProvider.selector.creditsExhausted": "Créditos agotados", + "modelProvider.selector.creditsExhaustedTip": "Tus AI credits se han agotado. Por favor, mejora tu plan o añade una API Key.", + "modelProvider.selector.disabled": "Desactivado", + "modelProvider.selector.discoverMoreInMarketplace": "Descubre más en el Marketplace", "modelProvider.selector.emptySetting": "Por favor ve a configuraciones para configurar", "modelProvider.selector.emptyTip": "No hay modelos disponibles", + "modelProvider.selector.fromMarketplace": "Desde el Marketplace", + "modelProvider.selector.incompatible": "Incompatible", + "modelProvider.selector.incompatibleTip": "Este modelo no está disponible en la versión actual. Por favor, selecciona otro modelo disponible.", + "modelProvider.selector.install": "Instalar", + "modelProvider.selector.modelProviderSettings": "Configuración del proveedor de modelos", + "modelProvider.selector.noProviderConfigured": "Ningún proveedor de modelos configurado", + "modelProvider.selector.noProviderConfiguredDesc": "Explora el Marketplace para instalar uno, o configura proveedores en los ajustes.", + "modelProvider.selector.onlyCompatibleModelsShown": "Solo se muestran los modelos compatibles", "modelProvider.selector.rerankTip": "Por favor configura el modelo de Reordenar", "modelProvider.selector.tip": "Este modelo ha sido eliminado. 
Por favor agrega un modelo o selecciona otro modelo.", "modelProvider.setupModelFirst": "Por favor configura tu modelo primero", diff --git a/web/i18n/es-ES/plugin.json b/web/i18n/es-ES/plugin.json index 297b86606d..c7d5855f24 100644 --- a/web/i18n/es-ES/plugin.json +++ b/web/i18n/es-ES/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Eliminar plugin", "action.deleteContentLeft": "¿Le gustaría eliminar", "action.deleteContentRight": "¿Complemento?", + "action.deleteSuccess": "Plugin eliminado correctamente", "action.pluginInfo": "Información del plugin", "action.usedInApps": "Este plugin se está utilizando en las aplicaciones {{num}}.", "allCategories": "Todas las categorías", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Instalar", "detailPanel.operation.remove": "Eliminar", "detailPanel.operation.update": "Actualizar", + "detailPanel.operation.updateTooltip": "Actualiza para acceder a los modelos más recientes.", "detailPanel.operation.viewDetail": "Ver Detalle", "detailPanel.serviceOk": "Servicio OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUIDO", @@ -231,12 +233,18 @@ "source.local": "Archivo de paquete local", "source.marketplace": "Mercado", "task.clearAll": "Borrar todo", + "task.errorMsg.github": "Este plugin no se pudo instalar automáticamente.\nPor favor, instálalo desde GitHub.", + "task.errorMsg.marketplace": "Este plugin no se pudo instalar automáticamente.\nPor favor, instálalo desde el Marketplace.", + "task.errorMsg.unknown": "Este plugin no se pudo instalar.\nNo se pudo identificar el origen del plugin.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Los complementos {{errorLength}} no se pudieron instalar, haga clic para ver", + "task.installFromGithub": "Instalar desde GitHub", + "task.installFromMarketplace": "Instalar desde el Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Los complementos {{errorLength}} no se pudieron instalar", "task.installing": "Instalando plugins.", + "task.installingHint": "Instalando... Esto puede tardar unos minutos.", "task.installingWithError": "Instalando plugins {{installingLength}}, {{successLength}} éxito, {{errorLength}} fallido", "task.installingWithSuccess": "Instalando plugins {{installingLength}}, {{successLength}} éxito.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/es-ES/workflow.json b/web/i18n/es-ES/workflow.json index d91cbd7cb2..393859a36f 100644 --- a/web/i18n/es-ES/workflow.json +++ b/web/i18n/es-ES/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "actualizando flujo de trabajo", "error.startNodeRequired": "Por favor, agregue primero un nodo de inicio antes de {{operation}}", "errorMsg.authRequired": "Se requiere autorización", + "errorMsg.configureModel": "Configura un modelo", "errorMsg.fieldRequired": "Se requiere {{field}}", "errorMsg.fields.code": "Código", "errorMsg.fields.model": "Modelo", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Variable de visión", "errorMsg.invalidJson": "{{field}} no es un JSON válido", "errorMsg.invalidVariable": "Variable no válida", + "errorMsg.modelPluginNotInstalled": "Variable no válida. 
Configura un modelo para habilitar esta variable.", "errorMsg.noValidTool": "{{field}} no se ha seleccionado ninguna herramienta válida", "errorMsg.rerankModelRequired": "Antes de activar el modelo de reclasificación, confirme que el modelo se ha configurado correctamente en la configuración.", "errorMsg.startNodeRequired": "Por favor, agregue primero un nodo de inicio antes de {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Tamaño de ventana", "nodes.common.outputVars": "Variables de salida", "nodes.common.pluginNotInstalled": "El complemento no está instalado", + "nodes.common.pluginsNotInstalled": "{{count}} plugins no instalados", "nodes.common.retry.maxRetries": "Número máximo de reintentos", "nodes.common.retry.ms": "Sra.", "nodes.common.retry.retries": "{{num}} Reintentos", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Trozo", "nodes.knowledgeBase.chunksInputTip": "La variable de entrada del nodo de la base de conocimientos es Chunks. El tipo de variable es un objeto con un esquema JSON específico que debe ser consistente con la estructura del fragmento seleccionado.", "nodes.knowledgeBase.chunksVariableIsRequired": "La variable Chunks es obligatoria", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key no disponible", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Créditos agotados", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatible", "nodes.knowledgeBase.embeddingModelIsInvalid": "El modelo de incrustación no es válido", "nodes.knowledgeBase.embeddingModelIsRequired": "Se requiere un modelo de incrustación", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modelo de embedding no configurado", "nodes.knowledgeBase.indexMethodIsRequired": "Se requiere el método de índice", + "nodes.knowledgeBase.notConfigured": "No configurado", "nodes.knowledgeBase.rerankingModelIsInvalid": "El modelo de reordenación no es válido", "nodes.knowledgeBase.rerankingModelIsRequired": "Se requiere un modelo de reordenamiento", "nodes.knowledgeBase.retrievalSettingIsRequired": "Se requiere configuración de recuperación", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Solo admite Jinja2", "nodes.templateTransform.inputVars": "Variables de entrada", "nodes.templateTransform.outputVars.output": "Contenido transformado", + "nodes.tool.authorizationRequired": "Autorización requerida", "nodes.tool.authorize": "autorizar", "nodes.tool.inputVars": "Variables de entrada", "nodes.tool.insertPlaceholder1": "Escribe o presiona", @@ -1062,10 +1071,12 @@ "panel.change": "Cambiar", "panel.changeBlock": "Cambiar Nodo", "panel.checklist": "Lista de verificación", + "panel.checklistDescription": "Resuelve los siguientes problemas antes de publicar", "panel.checklistResolved": "Se resolvieron todos los problemas", "panel.checklistTip": "Asegúrate de resolver todos los problemas antes de publicar", "panel.createdBy": "Creado por ", "panel.goTo": "Ir a", + "panel.goToFix": "Ir a corregir", "panel.helpLink": "Ayuda", "panel.maximize": "Maximizar Canvas", "panel.minimize": "Salir de pantalla completa", diff --git a/web/i18n/fa-IR/app-debug.json b/web/i18n/fa-IR/app-debug.json index 5427ebb72a..a612105f35 100644 --- a/web/i18n/fa-IR/app-debug.json +++ b/web/i18n/fa-IR/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "اجرا", "inputs.title": "اشکال زدایی و پیش نمایش", "inputs.userInputField": "فیلد ورودی کاربر", + "manageModels": "مدیریت مدل‌ها", "modelConfig.modeType.chat": "چت", "modelConfig.modeType.completion": "کامل", 
"modelConfig.model": "مدل", "modelConfig.setTone": "لحن پاسخ ها را تنظیم کنید", "modelConfig.title": "مدل و پارامترها", + "noModelProviderConfigured": "هیچ ارائه‌دهنده مدلی پیکربندی نشده است", + "noModelProviderConfiguredTip": "برای شروع، یک ارائه‌دهنده مدل نصب یا پیکربندی کنید.", + "noModelSelected": "هیچ مدلی انتخاب نشده است", + "noModelSelectedTip": "برای ادامه، یک مدل را در بالا پیکربندی کنید.", "noResult": "خروجی در اینجا نمایش داده می شود.", "notSetAPIKey.description": "کلید ارائه‌دهنده LLM تنظیم نشده است و باید قبل از دیباگ تنظیم شود.", "notSetAPIKey.settingBtn": "به تنظیمات بروید", diff --git a/web/i18n/fa-IR/common.json b/web/i18n/fa-IR/common.json index ec535138a4..d2b1e8158c 100644 --- a/web/i18n/fa-IR/common.json +++ b/web/i18n/fa-IR/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "بدون مجوز", "modelProvider.buyQuota": "خرید سهمیه", "modelProvider.callTimes": "تعداد فراخوانی", + "modelProvider.card.aiCreditsInUse": "اعتبار هوش مصنوعی در حال استفاده", + "modelProvider.card.aiCreditsOption": "اعتبار هوش مصنوعی", + "modelProvider.card.apiKeyOption": "کلید API", + "modelProvider.card.apiKeyRequired": "کلید API الزامی است", + "modelProvider.card.apiKeyUnavailableFallback": "کلید API در دسترس نیست، در حال استفاده از اعتبار هوش مصنوعی", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "پیکربندی کلید API خود را بررسی کنید تا بازگردید", "modelProvider.card.buyQuota": "خرید سهمیه", "modelProvider.card.callTimes": "تعداد فراخوانی", + "modelProvider.card.creditsExhaustedDescription": "لطفاً طرح خود را ارتقا دهید یا یک کلید API پیکربندی کنید", + "modelProvider.card.creditsExhaustedFallback": "اعتبار هوش مصنوعی تمام شده، در حال استفاده از کلید API", + "modelProvider.card.creditsExhaustedFallbackDescription": "طرح خود را ارتقا دهید تا اولویت اعتبار هوش مصنوعی از سر گرفته شود.", + "modelProvider.card.creditsExhaustedMessage": "اعتبار هوش مصنوعی تمام شده است", "modelProvider.card.modelAPI": "مدل‌های {{modelName}} از کلید API استفاده می‌کنند.", "modelProvider.card.modelNotSupported": "مدل‌های {{modelName}} نصب نشده‌اند.", "modelProvider.card.modelSupported": "مدل‌های {{modelName}} از این سهمیه استفاده می‌کنند.", + "modelProvider.card.noApiKeysDescription": "یک کلید API اضافه کنید تا از اعتبارنامه‌های مدل خود استفاده کنید.", + "modelProvider.card.noApiKeysFallback": "بدون کلید API، در حال استفاده از اعتبار هوش مصنوعی", + "modelProvider.card.noApiKeysTitle": "هنوز کلید API پیکربندی نشده است", + "modelProvider.card.noAvailableUsage": "هیچ مصرفی در دسترس نیست", "modelProvider.card.onTrial": "در حال آزمایش", "modelProvider.card.paid": "پرداخت شده", "modelProvider.card.priorityUse": "استفاده با اولویت", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "حذف کلید API", "modelProvider.card.tip": "اعتبار پیام از مدل‌های {{modelNames}} پشتیبانی می‌کند. اولویت به سهمیه پرداخت شده داده می‌شود. 
سهمیه رایگان پس از اتمام سهمیه پرداخت شده استفاده خواهد شد.", "modelProvider.card.tokens": "توکن‌ها", + "modelProvider.card.unavailable": "در دسترس نیست", + "modelProvider.card.upgradePlan": "طرح خود را ارتقا دهید", + "modelProvider.card.usageLabel": "مصرف", + "modelProvider.card.usagePriority": "اولویت مصرف", + "modelProvider.card.usagePriorityTip": "تعیین کنید که هنگام اجرای مدل‌ها کدام منبع اول استفاده شود.", "modelProvider.collapse": "جمع کردن", "modelProvider.config": "پیکربندی", "modelProvider.configLoadBalancing": "پیکربندی تعادل بار", @@ -387,9 +406,11 @@ "modelProvider.model": "مدل", "modelProvider.modelAndParameters": "مدل و پارامترها", "modelProvider.modelHasBeenDeprecated": "این مدل منسوخ شده است", + "modelProvider.modelSettings": "تنظیمات مدل", "modelProvider.models": "مدل‌ها", "modelProvider.modelsNum": "{{num}} مدل", "modelProvider.noModelFound": "هیچ مدلی برای {{model}} یافت نشد", + "modelProvider.noneConfigured": "یک مدل سیستم پیش‌فرض برای اجرای برنامه‌ها پیکربندی کنید", "modelProvider.notConfigured": "مدل سیستم هنوز به طور کامل پیکربندی نشده است و برخی از عملکردها ممکن است در دسترس نباشند.", "modelProvider.parameters": "پارامترها", "modelProvider.parametersInvalidRemoved": "برخی پارامترها نامعتبر هستند و حذف شده‌اند", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "بازنشانی در {{date}}", "modelProvider.searchModel": "جستجوی مدل", "modelProvider.selectModel": "مدل خود را انتخاب کنید", + "modelProvider.selector.aiCredits": "اعتبار هوش مصنوعی", + "modelProvider.selector.apiKeyUnavailable": "کلید API در دسترس نیست", + "modelProvider.selector.apiKeyUnavailableTip": "کلید API حذف شده است. لطفاً یک کلید API جدید پیکربندی کنید.", + "modelProvider.selector.configure": "پیکربندی", + "modelProvider.selector.configureRequired": "پیکربندی الزامی است", + "modelProvider.selector.creditsExhausted": "اعتبار تمام شده", + "modelProvider.selector.creditsExhaustedTip": "اعتبار هوش مصنوعی شما تمام شده است. لطفاً طرح خود را ارتقا دهید یا یک کلید API اضافه کنید.", + "modelProvider.selector.disabled": "غیرفعال", + "modelProvider.selector.discoverMoreInMarketplace": "اطلاعات بیشتر در Marketplace", "modelProvider.selector.emptySetting": "لطفاً به تنظیمات بروید تا پیکربندی کنید", "modelProvider.selector.emptyTip": "هیچ مدل موجودی وجود ندارد", + "modelProvider.selector.fromMarketplace": "از Marketplace", + "modelProvider.selector.incompatible": "ناسازگار", + "modelProvider.selector.incompatibleTip": "این مدل در نسخه فعلی موجود نیست. لطفاً مدل دیگری انتخاب کنید.", + "modelProvider.selector.install": "نصب", + "modelProvider.selector.modelProviderSettings": "تنظیمات ارائه‌دهنده مدل", + "modelProvider.selector.noProviderConfigured": "هیچ ارائه‌دهنده مدلی پیکربندی نشده است", + "modelProvider.selector.noProviderConfiguredDesc": "در Marketplace جستجو کنید تا یکی نصب کنید، یا ارائه‌دهندگان را در تنظیمات پیکربندی کنید.", + "modelProvider.selector.onlyCompatibleModelsShown": "فقط مدل‌های سازگار نمایش داده می‌شوند", "modelProvider.selector.rerankTip": "لطفاً مدل رتبه‌بندی مجدد را تنظیم کنید", "modelProvider.selector.tip": "این مدل حذف شده است. 
لطفاً یک مدل اضافه کنید یا مدل دیگری را انتخاب کنید.", "modelProvider.setupModelFirst": "لطفاً ابتدا مدل خود را تنظیم کنید", diff --git a/web/i18n/fa-IR/plugin.json b/web/i18n/fa-IR/plugin.json index a1bbc8ee63..bee2111b44 100644 --- a/web/i18n/fa-IR/plugin.json +++ b/web/i18n/fa-IR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "حذف افزونه", "action.deleteContentLeft": "آیا می خواهید", "action.deleteContentRight": "افزونه?", + "action.deleteSuccess": "افزونه با موفقیت حذف شد", "action.pluginInfo": "اطلاعات پلاگین", "action.usedInApps": "این افزونه در برنامه های {{num}} استفاده می شود.", "allCategories": "همه دسته بندی ها", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "نصب", "detailPanel.operation.remove": "حذف", "detailPanel.operation.update": "روز رسانی", + "detailPanel.operation.updateTooltip": "به‌روزرسانی کنید تا به جدیدترین مدل‌ها دسترسی پیدا کنید.", "detailPanel.operation.viewDetail": "نمایش جزئیات", "detailPanel.serviceOk": "خدمات خوب", "detailPanel.strategyNum": "{{num}} {{strategy}} شامل", @@ -231,12 +233,18 @@ "source.local": "فایل بسته محلی", "source.marketplace": "بازار", "task.clearAll": "پاک کردن همه", + "task.errorMsg.github": "این افزونه به‌صورت خودکار نصب نشد.\nلطفاً آن را از GitHub نصب کنید.", + "task.errorMsg.marketplace": "این افزونه به‌صورت خودکار نصب نشد.\nلطفاً آن را از Marketplace نصب کنید.", + "task.errorMsg.unknown": "این افزونه نصب نشد.\nمنبع افزونه قابل شناسایی نبود.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "پلاگین های {{errorLength}} نصب نشدند، برای مشاهده کلیک کنید", + "task.installFromGithub": "نصب از GitHub", + "task.installFromMarketplace": "نصب از Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "افزونه های {{errorLength}} نصب نشدند", "task.installing": "در حال نصب پلاگین‌ها.", + "task.installingHint": "در حال نصب... این ممکن است چند دقیقه طول بکشد.", "task.installingWithError": "نصب پلاگین های {{installingLength}}، {{successLength}} با موفقیت مواجه شد، {{errorLength}} ناموفق بود", "task.installingWithSuccess": "نصب پلاگین های {{installingLength}}، {{successLength}} موفقیت آمیز است.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/fa-IR/workflow.json b/web/i18n/fa-IR/workflow.json index 676179e6d4..7a8fca11f1 100644 --- a/web/i18n/fa-IR/workflow.json +++ b/web/i18n/fa-IR/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "به‌روزرسانی گردش کار", "error.startNodeRequired": "لطفاً قبل از {{operation}} ابتدا یک گره شروع اضافه کنید", "errorMsg.authRequired": "احراز هویت الزامی است", + "errorMsg.configureModel": "یک مدل پیکربندی کنید", "errorMsg.fieldRequired": "{{field}} الزامی است", "errorMsg.fields.code": "کد", "errorMsg.fields.model": "مدل", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "متغیر بینایی", "errorMsg.invalidJson": "{{field}} یک JSON معتبر نیست", "errorMsg.invalidVariable": "متغیر نامعتبر", + "errorMsg.modelPluginNotInstalled": "متغیر نامعتبر. 
برای فعال‌سازی این متغیر، یک مدل پیکربندی کنید.", "errorMsg.noValidTool": "{{field}} هیچ ابزار معتبری انتخاب نشده است", "errorMsg.rerankModelRequired": "قبل از فعال‌سازی Rerank Model، لطفاً مطمئن شوید که مدل در تنظیمات با موفقیت پیکربندی شده است.", "errorMsg.startNodeRequired": "لطفاً قبل از {{operation}} ابتدا یک گره شروع اضافه کنید", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "اندازه پنجره", "nodes.common.outputVars": "متغیرهای خروجی", "nodes.common.pluginNotInstalled": "افزونه نصب نشده است", + "nodes.common.pluginsNotInstalled": "{{count}} افزونه نصب نشده است", "nodes.common.retry.maxRetries": "حداکثر تلاش مجدد", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} تلاش مجدد", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "چانک‌ها", "nodes.knowledgeBase.chunksInputTip": "متغیر ورودی گره پایگاه دانش چانک‌ها است. نوع متغیر یک شیء با طرح JSON خاص است که باید با ساختار چانک انتخاب‌شده سازگار باشد.", "nodes.knowledgeBase.chunksVariableIsRequired": "متغیر چانک‌ها الزامی است", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "کلید API در دسترس نیست", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "اعتبار تمام شده", + "nodes.knowledgeBase.embeddingModelIncompatible": "ناسازگار", "nodes.knowledgeBase.embeddingModelIsInvalid": "مدل Embedding نامعتبر است", "nodes.knowledgeBase.embeddingModelIsRequired": "مدل Embedding الزامی است", + "nodes.knowledgeBase.embeddingModelNotConfigured": "مدل Embedding پیکربندی نشده است", "nodes.knowledgeBase.indexMethodIsRequired": "روش ایندکس‌گذاری الزامی است", + "nodes.knowledgeBase.notConfigured": "پیکربندی نشده", "nodes.knowledgeBase.rerankingModelIsInvalid": "مدل بازرتبه‌بندی نامعتبر است", "nodes.knowledgeBase.rerankingModelIsRequired": "مدل بازرتبه‌بندی الزامی است", "nodes.knowledgeBase.retrievalSettingIsRequired": "تنظیمات بازیابی الزامی است", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "فقط از Jinja2 پشتیبانی می‌شود", "nodes.templateTransform.inputVars": "متغیرهای ورودی", "nodes.templateTransform.outputVars.output": "محتوای تبدیل‌شده", + "nodes.tool.authorizationRequired": "احراز هویت الزامی است", "nodes.tool.authorize": "مجوزدهی", "nodes.tool.inputVars": "متغیرهای ورودی", "nodes.tool.insertPlaceholder1": "تایپ کنید یا فشار دهید", @@ -1062,10 +1071,12 @@ "panel.change": "تغییر", "panel.changeBlock": "تغییر گره", "panel.checklist": "چک‌لیست", + "panel.checklistDescription": "مشکلات زیر را پیش از انتشار برطرف کنید", "panel.checklistResolved": "تمام مشکلات برطرف شده‌اند", "panel.checklistTip": "قبل از انتشار، مطمئن شوید که تمام مشکلات برطرف شده‌اند", "panel.createdBy": "ساخته شده توسط", "panel.goTo": "برو به", + "panel.goToFix": "برو به رفع", "panel.helpLink": "راهنما", "panel.maximize": "تمام‌صفحه", "panel.minimize": "خروج از تمام‌صفحه", diff --git a/web/i18n/fr-FR/app-debug.json b/web/i18n/fr-FR/app-debug.json index 88f75f5136..711959c323 100644 --- a/web/i18n/fr-FR/app-debug.json +++ b/web/i18n/fr-FR/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "EXÉCUTER", "inputs.title": "Déboguer et Aperçu", "inputs.userInputField": "Champ de saisie utilisateur", + "manageModels": "Gérer les modèles", "modelConfig.modeType.chat": "Discussion", "modelConfig.modeType.completion": "Complet", "modelConfig.model": "Modèle", "modelConfig.setTone": "Définir le ton des réponses", "modelConfig.title": "Modèle et Paramètres", + "noModelProviderConfigured": "Aucun fournisseur de modèle configuré", + "noModelProviderConfiguredTip": "Installez ou configurez un fournisseur de modèle pour commencer.", + 
"noModelSelected": "Aucun modèle sélectionné", + "noModelSelectedTip": "Configurez un modèle ci-dessus pour continuer.", "noResult": "La sortie sera affichée ici.", "notSetAPIKey.description": "La clé du fournisseur LLM n'a pas été définie, et elle doit être définie avant le débogage.", "notSetAPIKey.settingBtn": "Aller aux paramètres", diff --git a/web/i18n/fr-FR/common.json b/web/i18n/fr-FR/common.json index 42b40b2a2a..8710c04c44 100644 --- a/web/i18n/fr-FR/common.json +++ b/web/i18n/fr-FR/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Non autorisé", "modelProvider.buyQuota": "Acheter Quota", "modelProvider.callTimes": "Temps d'appel", + "modelProvider.card.aiCreditsInUse": "AI credits en cours d'utilisation", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API Key requise", + "modelProvider.card.apiKeyUnavailableFallback": "API Key indisponible, utilisation des AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Vérifiez la configuration de votre API Key pour revenir en arrière", "modelProvider.card.buyQuota": "Acheter Quota", "modelProvider.card.callTimes": "Temps d'appel", + "modelProvider.card.creditsExhaustedDescription": "Veuillez mettre à niveau votre forfait ou configurer une API Key", + "modelProvider.card.creditsExhaustedFallback": "AI credits épuisés, utilisation de l'API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Mettez à niveau votre forfait pour rétablir la priorité des AI credits.", + "modelProvider.card.creditsExhaustedMessage": "Les AI credits ont été épuisés", "modelProvider.card.modelAPI": "Les modèles {{modelName}} utilisent la clé API.", "modelProvider.card.modelNotSupported": "Les modèles {{modelName}} ne sont pas installés.", "modelProvider.card.modelSupported": "Les modèles {{modelName}} utilisent ce quota.", + "modelProvider.card.noApiKeysDescription": "Ajoutez une API Key pour utiliser vos propres identifiants de modèle.", + "modelProvider.card.noApiKeysFallback": "Aucune API Key, utilisation des AI credits", + "modelProvider.card.noApiKeysTitle": "Aucune API Key configurée", + "modelProvider.card.noAvailableUsage": "Aucune utilisation disponible", "modelProvider.card.onTrial": "En Essai", "modelProvider.card.paid": "Payé", "modelProvider.card.priorityUse": "Utilisation prioritaire", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Supprimer la clé API", "modelProvider.card.tip": "Les crédits de messages prennent en charge les modèles de {{modelNames}}. La priorité sera donnée au quota payant. 
Le quota gratuit sera utilisé après épuisement du quota payant.", "modelProvider.card.tokens": "Jetons", + "modelProvider.card.unavailable": "Indisponible", + "modelProvider.card.upgradePlan": "mettre à niveau votre forfait", + "modelProvider.card.usageLabel": "Utilisation", + "modelProvider.card.usagePriority": "Priorité d'utilisation", + "modelProvider.card.usagePriorityTip": "Définissez la ressource à utiliser en priorité lors de l'exécution des modèles.", "modelProvider.collapse": "Effondrer", "modelProvider.config": "Configuration", "modelProvider.configLoadBalancing": "Équilibrage de charge de configuration", @@ -387,9 +406,11 @@ "modelProvider.model": "Modèle", "modelProvider.modelAndParameters": "Modèle et Paramètres", "modelProvider.modelHasBeenDeprecated": "Ce modèle est obsolète", + "modelProvider.modelSettings": "Paramètres du modèle", "modelProvider.models": "Modèles", "modelProvider.modelsNum": "{{num}} Modèles", "modelProvider.noModelFound": "Aucun modèle trouvé pour {{model}}", + "modelProvider.noneConfigured": "Configurez un modèle système par défaut pour exécuter les applications", "modelProvider.notConfigured": "Le modèle du système n'a pas encore été entièrement configuré, et certaines fonctions peuvent être indisponibles.", "modelProvider.parameters": "PARAMÈTRES", "modelProvider.parametersInvalidRemoved": "Certains paramètres sont invalides et ont été supprimés.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Réinitialiser le {{date}}", "modelProvider.searchModel": "Modèle de recherche", "modelProvider.selectModel": "Sélectionnez votre modèle", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key indisponible", + "modelProvider.selector.apiKeyUnavailableTip": "L'API Key a été supprimée. Veuillez configurer une nouvelle API Key.", + "modelProvider.selector.configure": "Configurer", + "modelProvider.selector.configureRequired": "Configuration requise", + "modelProvider.selector.creditsExhausted": "Crédits épuisés", + "modelProvider.selector.creditsExhaustedTip": "Vos AI credits ont été épuisés. Veuillez mettre à niveau votre forfait ou ajouter une API Key.", + "modelProvider.selector.disabled": "Désactivé", + "modelProvider.selector.discoverMoreInMarketplace": "Découvrir plus sur le Marketplace", "modelProvider.selector.emptySetting": "Veuillez aller dans les paramètres pour configurer", "modelProvider.selector.emptyTip": "Aucun modèle disponible", + "modelProvider.selector.fromMarketplace": "Depuis le Marketplace", + "modelProvider.selector.incompatible": "Incompatible", + "modelProvider.selector.incompatibleTip": "Ce modèle n'est pas disponible dans la version actuelle. Veuillez sélectionner un autre modèle disponible.", + "modelProvider.selector.install": "Installer", + "modelProvider.selector.modelProviderSettings": "Paramètres du fournisseur de modèle", + "modelProvider.selector.noProviderConfigured": "Aucun fournisseur de modèle configuré", + "modelProvider.selector.noProviderConfiguredDesc": "Parcourez le Marketplace pour en installer un, ou configurez les fournisseurs dans les paramètres.", + "modelProvider.selector.onlyCompatibleModelsShown": "Seuls les modèles compatibles sont affichés", "modelProvider.selector.rerankTip": "Veuillez configurer le modèle Rerank", "modelProvider.selector.tip": "Ce modèle a été supprimé. 
Veuillez ajouter un modèle ou sélectionner un autre modèle.", "modelProvider.setupModelFirst": "Veuillez d'abord configurer votre modèle", diff --git a/web/i18n/fr-FR/plugin.json b/web/i18n/fr-FR/plugin.json index 79f43acb8e..cbea19646d 100644 --- a/web/i18n/fr-FR/plugin.json +++ b/web/i18n/fr-FR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Supprimer le plugin", "action.deleteContentLeft": "Souhaitez-vous supprimer", "action.deleteContentRight": "Plug-in ?", + "action.deleteSuccess": "Plugin supprimé avec succès", "action.pluginInfo": "Informations sur le plugin", "action.usedInApps": "Ce plugin est utilisé dans les applications {{num}}.", "allCategories": "Toutes les catégories", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Installer", "detailPanel.operation.remove": "Enlever", "detailPanel.operation.update": "Mettre à jour", + "detailPanel.operation.updateTooltip": "Mettez à jour pour accéder aux derniers modèles.", "detailPanel.operation.viewDetail": "Voir les détails", "detailPanel.serviceOk": "Service OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUS", @@ -231,12 +233,18 @@ "source.local": "Fichier de package local", "source.marketplace": "Marché", "task.clearAll": "Effacer tout", + "task.errorMsg.github": "Ce plugin n'a pas pu être installé automatiquement.\nVeuillez l'installer depuis GitHub.", + "task.errorMsg.marketplace": "Ce plugin n'a pas pu être installé automatiquement.\nVeuillez l'installer depuis le Marketplace.", + "task.errorMsg.unknown": "Ce plugin n'a pas pu être installé.\nLa source du plugin n'a pas pu être identifiée.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} les plugins n’ont pas pu être installés, cliquez pour voir", + "task.installFromGithub": "Installer depuis GitHub", + "task.installFromMarketplace": "Installer depuis le Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} les plugins n’ont pas pu être installés", "task.installing": "Installation des plugins.", + "task.installingHint": "Installation en cours... Cela peut prendre quelques minutes.", "task.installingWithError": "Installation des plugins {{installingLength}}, succès de {{successLength}}, échec de {{errorLength}}", "task.installingWithSuccess": "Installation des plugins {{installingLength}}, succès de {{successLength}}.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/fr-FR/workflow.json b/web/i18n/fr-FR/workflow.json index b5f67f74b3..09d140445e 100644 --- a/web/i18n/fr-FR/workflow.json +++ b/web/i18n/fr-FR/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "mise à jour du flux de travail", "error.startNodeRequired": "Veuillez d'abord ajouter un nœud de départ avant {{operation}}", "errorMsg.authRequired": "Autorisation requise", + "errorMsg.configureModel": "Configurez un modèle", "errorMsg.fieldRequired": "{{field}} est requis", "errorMsg.fields.code": "Code", "errorMsg.fields.model": "Modèle", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vision Variable", "errorMsg.invalidJson": "{{field}} est un JSON invalide", "errorMsg.invalidVariable": "Variable invalide", + "errorMsg.modelPluginNotInstalled": "Variable invalide. 
Configurez un modèle pour activer cette variable.", "errorMsg.noValidTool": "{{field}} aucun outil valide sélectionné", "errorMsg.rerankModelRequired": "Avant d’activer le modèle de reclassement, veuillez confirmer que le modèle a été correctement configuré dans les paramètres.", "errorMsg.startNodeRequired": "Veuillez d'abord ajouter un nœud de départ avant {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Taille de la fenêtre", "nodes.common.outputVars": "Variables de sortie", "nodes.common.pluginNotInstalled": "Le plugin n'est pas installé", + "nodes.common.pluginsNotInstalled": "{{count}} plugins non installés", "nodes.common.retry.maxRetries": "Nombre maximal de tentatives", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Tentatives", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Morceaux", "nodes.knowledgeBase.chunksInputTip": "La variable d'entrée du nœud de la base de connaissances est Chunks. Le type de variable est un objet avec un schéma JSON spécifique qui doit être cohérent avec la structure de morceau sélectionnée.", "nodes.knowledgeBase.chunksVariableIsRequired": "La variable Chunks est requise", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key indisponible", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Crédits épuisés", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatible", "nodes.knowledgeBase.embeddingModelIsInvalid": "Le modèle d'intégration est invalide", - "nodes.knowledgeBase.embeddingModelIsRequired": "Un modèle d'intégration est requis", + "nodes.knowledgeBase.embeddingModelIsRequired": "Un modèle d’intégration est requis", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modèle d’embedding non configuré", "nodes.knowledgeBase.indexMethodIsRequired": "La méthode d’indexation est requise", + "nodes.knowledgeBase.notConfigured": "Non configuré", "nodes.knowledgeBase.rerankingModelIsInvalid": "Le modèle de rerank est invalide", "nodes.knowledgeBase.rerankingModelIsRequired": "Un modèle de rerankage est requis", "nodes.knowledgeBase.retrievalSettingIsRequired": "Le paramètre de récupération est requis", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Prend en charge uniquement Jinja2", "nodes.templateTransform.inputVars": "Variables de saisie", "nodes.templateTransform.outputVars.output": "Contenu transformé", + "nodes.tool.authorizationRequired": "Autorisation requise", "nodes.tool.authorize": "Autoriser", "nodes.tool.inputVars": "Variables de saisie", "nodes.tool.insertPlaceholder1": "Tapez ou appuyez", @@ -1062,10 +1071,12 @@ "panel.change": "Modifier", "panel.changeBlock": "Changer de nœud", "panel.checklist": "Liste de contrôle", + "panel.checklistDescription": "Résolvez les problèmes suivants avant de publier", "panel.checklistResolved": "Tous les problèmes ont été résolus", "panel.checklistTip": "Assurez-vous que tous les problèmes sont résolus avant de publier", "panel.createdBy": "Créé par", "panel.goTo": "Aller à", + "panel.goToFix": "Aller corriger", "panel.helpLink": "Aide", "panel.maximize": "Maximiser le Canvas", "panel.minimize": "Sortir du mode plein écran", diff --git a/web/i18n/hi-IN/app-debug.json b/web/i18n/hi-IN/app-debug.json index 97e733e167..f9d438f0b1 100644 --- a/web/i18n/hi-IN/app-debug.json +++ b/web/i18n/hi-IN/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "चालू करें", "inputs.title": "डिबग और पूर्वावलोकन", "inputs.userInputField": "उपयोगकर्ता इनपुट फ़ील्ड", + "manageModels": "मॉडल प्रबंधित करें", 
"modelConfig.modeType.chat": "चैट", "modelConfig.modeType.completion": "पूर्ण", "modelConfig.model": "मॉडल", "modelConfig.setTone": "प्रतिक्रियाओं की टोन सेट करें", "modelConfig.title": "मॉडल और पैरामीटर", + "noModelProviderConfigured": "कोई मॉडल प्रदाता कॉन्फ़िगर नहीं है", + "noModelProviderConfiguredTip": "शुरू करने के लिए एक मॉडल प्रदाता इंस्टॉल या कॉन्फ़िगर करें।", + "noModelSelected": "कोई मॉडल चयनित नहीं है", + "noModelSelectedTip": "जारी रखने के लिए ऊपर एक मॉडल कॉन्फ़िगर करें।", "noResult": "प्रदर्शन यहाँ होगा।", "notSetAPIKey.description": "एलएलएम प्रदाता कुंजी सेट नहीं की गई है, और डीबग करने से पहले इसे सेट करने की आवश्यकता है।", "notSetAPIKey.settingBtn": "सेटिंग्स पर जाएं", diff --git a/web/i18n/hi-IN/common.json b/web/i18n/hi-IN/common.json index 75ac76add8..e61c96ca45 100644 --- a/web/i18n/hi-IN/common.json +++ b/web/i18n/hi-IN/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "अअनधिकारित", "modelProvider.buyQuota": "कोटा खरीदें", "modelProvider.callTimes": "कॉल समय", + "modelProvider.card.aiCreditsInUse": "AI credits उपयोग में हैं", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API key आवश्यक है", + "modelProvider.card.apiKeyUnavailableFallback": "API Key अनुपलब्ध, AI credits का उपयोग हो रहा है", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "वापस स्विच करने के लिए अपना API key कॉन्फ़िगरेशन जाँचें", "modelProvider.card.buyQuota": "कोटा खरीदें", "modelProvider.card.callTimes": "कॉल समय", + "modelProvider.card.creditsExhaustedDescription": "कृपया अपना प्लान अपग्रेड करें या API key कॉन्फ़िगर करें", + "modelProvider.card.creditsExhaustedFallback": "AI credits समाप्त, API Key का उपयोग हो रहा है", + "modelProvider.card.creditsExhaustedFallbackDescription": "AI credit प्राथमिकता फिर से शुरू करने के लिए अपना प्लान अपग्रेड करें।", + "modelProvider.card.creditsExhaustedMessage": "AI credits समाप्त हो गए हैं", "modelProvider.card.modelAPI": "{{modelName}} मॉडल API कुंजी का उपयोग कर रहे हैं।", "modelProvider.card.modelNotSupported": "{{modelName}} मॉडल इंस्टॉल नहीं हैं।", "modelProvider.card.modelSupported": "{{modelName}} मॉडल इस कोटा का उपयोग कर रहे हैं।", + "modelProvider.card.noApiKeysDescription": "अपने स्वयं के मॉडल क्रेडेंशियल का उपयोग शुरू करने के लिए API key जोड़ें।", + "modelProvider.card.noApiKeysFallback": "कोई API key नहीं, AI credits का उपयोग हो रहा है", + "modelProvider.card.noApiKeysTitle": "अभी तक कोई API key कॉन्फ़िगर नहीं है", + "modelProvider.card.noAvailableUsage": "कोई उपलब्ध उपयोग नहीं", "modelProvider.card.onTrial": "परीक्षण पर", "modelProvider.card.paid": "भुगतान किया हुआ", "modelProvider.card.priorityUse": "प्राथमिकता उपयोग", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "API कुंजी निकालें", "modelProvider.card.tip": "संदेश क्रेडिट {{modelNames}} के मॉडल का समर्थन करते हैं। भुगतान किए गए कोटा को प्राथमिकता दी जाएगी। भुगतान किए गए कोटा के समाप्त होने के बाद मुफ्त कोटा का उपयोग किया जाएगा।", "modelProvider.card.tokens": "टोकन", + "modelProvider.card.unavailable": "अनुपलब्ध", + "modelProvider.card.upgradePlan": "अपना प्लान अपग्रेड करें", + "modelProvider.card.usageLabel": "उपयोग", + "modelProvider.card.usagePriority": "उपयोग प्राथमिकता", + "modelProvider.card.usagePriorityTip": "मॉडल चलाते समय पहले किस संसाधन का उपयोग करना है, यह सेट करें।", "modelProvider.collapse": "संक्षिप्त करें", "modelProvider.config": "कॉन्फ़िग", "modelProvider.configLoadBalancing": "लोड बैलेंसिंग कॉन्फ़िग करें", @@ -387,9 +406,11 @@ "modelProvider.model": 
"मॉडल", "modelProvider.modelAndParameters": "मॉडल और पैरामीटर", "modelProvider.modelHasBeenDeprecated": "यह मॉडल अप्रचलित हो गया है", + "modelProvider.modelSettings": "मॉडल सेटिंग्स", "modelProvider.models": "मॉडल्स", "modelProvider.modelsNum": "{{num}} मॉडल्स", "modelProvider.noModelFound": "{{model}} के लिए कोई मॉडल नहीं मिला", + "modelProvider.noneConfigured": "एप्लिकेशन चलाने के लिए एक डिफ़ॉल्ट सिस्टम मॉडल कॉन्फ़िगर करें", "modelProvider.notConfigured": "सिस्टम मॉडल को अभी पूरी तरह से कॉन्फ़िगर नहीं किया गया है, और कुछ कार्य उपलब्ध नहीं हो सकते हैं।", "modelProvider.parameters": "पैरामीटर", "modelProvider.parametersInvalidRemoved": "कुछ पैरामीटर अमान्य हैं और हटा दिए गए हैं", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "{{date}} को रीसेट करें", "modelProvider.searchModel": "खोज मॉडल", "modelProvider.selectModel": "अपने मॉडल का चयन करें", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key अनुपलब्ध", + "modelProvider.selector.apiKeyUnavailableTip": "API key हटा दी गई है। कृपया एक नई API key कॉन्फ़िगर करें।", + "modelProvider.selector.configure": "कॉन्फ़िगर करें", + "modelProvider.selector.configureRequired": "कॉन्फ़िगरेशन आवश्यक", + "modelProvider.selector.creditsExhausted": "क्रेडिट समाप्त", + "modelProvider.selector.creditsExhaustedTip": "आपके AI credits समाप्त हो गए हैं। कृपया अपना प्लान अपग्रेड करें या API key जोड़ें।", + "modelProvider.selector.disabled": "अक्षम", + "modelProvider.selector.discoverMoreInMarketplace": "Marketplace में और खोजें", "modelProvider.selector.emptySetting": "कॉन्फ़िगर करने के लिए कृपया सेटिंग्स पर जाएं", "modelProvider.selector.emptyTip": "कोई उपलब्ध मॉडल नहीं", + "modelProvider.selector.fromMarketplace": "Marketplace से", + "modelProvider.selector.incompatible": "असंगत", + "modelProvider.selector.incompatibleTip": "यह मॉडल वर्तमान संस्करण में उपलब्ध नहीं है। कृपया कोई अन्य उपलब्ध मॉडल चुनें।", + "modelProvider.selector.install": "इंस्टॉल करें", + "modelProvider.selector.modelProviderSettings": "मॉडल प्रदाता सेटिंग्स", + "modelProvider.selector.noProviderConfigured": "कोई मॉडल प्रदाता कॉन्फ़िगर नहीं है", + "modelProvider.selector.noProviderConfiguredDesc": "इंस्टॉल करने के लिए Marketplace ब्राउज़ करें, या सेटिंग्स में प्रदाता कॉन्फ़िगर करें।", + "modelProvider.selector.onlyCompatibleModelsShown": "केवल संगत मॉडल दिखाए गए हैं", "modelProvider.selector.rerankTip": "कृपया रीरैंक मॉडल सेट करें", "modelProvider.selector.tip": "इस मॉडल को हटा दिया गया है। कृपया एक मॉडल जोड़ें या किसी अन्य मॉडल का चयन करें।", "modelProvider.setupModelFirst": "कृपया पहले अपना मॉडल सेट करें", diff --git a/web/i18n/hi-IN/plugin.json b/web/i18n/hi-IN/plugin.json index 5a16c2676e..8d3a1c2d1d 100644 --- a/web/i18n/hi-IN/plugin.json +++ b/web/i18n/hi-IN/plugin.json @@ -3,6 +3,7 @@ "action.delete": "प्लगइन हटाएं", "action.deleteContentLeft": "क्या आप हटाना चाहेंगे", "action.deleteContentRight": "प्लगइन?", + "action.deleteSuccess": "प्लगइन सफलतापूर्वक हटाया गया", "action.pluginInfo": "प्लगइन जानकारी", "action.usedInApps": "यह प्लगइन {{num}} ऐप्स में उपयोग किया जा रहा है।", "allCategories": "सभी श्रेणियाँ", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "स्थापित करें", "detailPanel.operation.remove": "हटाएं", "detailPanel.operation.update": "अपडेट", + "detailPanel.operation.updateTooltip": "नवीनतम मॉडल तक पहुँचने के लिए अपडेट करें।", "detailPanel.operation.viewDetail": "विवरण देखें", "detailPanel.serviceOk": "सेवा ठीक है", "detailPanel.strategyNum": "{{num}} {{strategy}} शामिल", @@ -231,12 +233,18 @@ "source.local": "स्थानीय 
पैकेज फ़ाइल", "source.marketplace": "बाजार", "task.clearAll": "सभी साफ करें", + "task.errorMsg.github": "यह प्लगइन स्वचालित रूप से इंस्टॉल नहीं हो सका।\nकृपया इसे GitHub से इंस्टॉल करें।", + "task.errorMsg.marketplace": "यह प्लगइन स्वचालित रूप से इंस्टॉल नहीं हो सका।\nकृपया इसे Marketplace से इंस्टॉल करें।", + "task.errorMsg.unknown": "यह प्लगइन इंस्टॉल नहीं हो सका।\nप्लगइन स्रोत की पहचान नहीं हो सकी।", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} प्लगइन्स स्थापित करने में विफल रहे, देखने के लिए क्लिक करें", + "task.installFromGithub": "GitHub से इंस्टॉल करें", + "task.installFromMarketplace": "Marketplace से इंस्टॉल करें", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} प्लगइन्स स्थापित करने में विफल रहे", "task.installing": "प्लगइन्स स्थापित कर रहे हैं।", + "task.installingHint": "इंस्टॉल हो रहा है... इसमें कुछ मिनट लग सकते हैं।", "task.installingWithError": "{{installingLength}} प्लगइन्स स्थापित कर रहे हैं, {{successLength}} सफल, {{errorLength}} विफल", "task.installingWithSuccess": "{{installingLength}} प्लगइन्स स्थापित कर रहे हैं, {{successLength}} सफल।", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/hi-IN/workflow.json b/web/i18n/hi-IN/workflow.json index 31a3d4627f..fb9256536e 100644 --- a/web/i18n/hi-IN/workflow.json +++ b/web/i18n/hi-IN/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "वर्कफ़्लो को अपडेट करना", "error.startNodeRequired": "कृपया {{operation}} से पहले एक प्रारंभ नोड जोड़ें", "errorMsg.authRequired": "प्राधिकरण आवश्यक है", + "errorMsg.configureModel": "एक मॉडल कॉन्फ़िगर करें", "errorMsg.fieldRequired": "{{field}} आवश्यक है", "errorMsg.fields.code": "कोड", "errorMsg.fields.model": "मॉडल", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "दृष्टि चर", "errorMsg.invalidJson": "{{field}} अमान्य JSON है", "errorMsg.invalidVariable": "अमान्य वेरिएबल", + "errorMsg.modelPluginNotInstalled": "अमान्य वेरिएबल। इस वेरिएबल को सक्षम करने के लिए एक मॉडल कॉन्फ़िगर करें।", "errorMsg.noValidTool": "{{field}} कोई मान्य उपकरण चयनित नहीं किया गया", "errorMsg.rerankModelRequired": "Rerank मॉडल चालू करने से पहले, कृपया पुष्टि करें कि मॉडल को सेटिंग्स में सफलतापूर्वक कॉन्फ़िगर किया गया है।", "errorMsg.startNodeRequired": "कृपया {{operation}} से पहले पहले एक स्टार्ट नोड जोड़ें", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "विंडो साइज", "nodes.common.outputVars": "आउटपुट वेरिएबल्स", "nodes.common.pluginNotInstalled": "प्लगइन इंस्टॉल नहीं है", + "nodes.common.pluginsNotInstalled": "{{count}} प्लगइन इंस्टॉल नहीं हैं", "nodes.common.retry.maxRetries": "अधिकतम पुनः प्रयास करता है", "nodes.common.retry.ms": "सुश्री", "nodes.common.retry.retries": "{{num}} पुनर्प्रयास", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "टुकड़े", "nodes.knowledgeBase.chunksInputTip": "ज्ञान आधार नोड का इनपुट वेरिएबल टुकड़े है। वेरिएबल प्रकार एक ऑब्जेक्ट है जिसमें एक विशेष JSON स्कीमा है जो चयनित चंक संरचना के साथ सुसंगत होना चाहिए।", "nodes.knowledgeBase.chunksVariableIsRequired": "टुकड़े चर आवश्यक है", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API key अनुपलब्ध", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "क्रेडिट समाप्त", + "nodes.knowledgeBase.embeddingModelIncompatible": "असंगत", "nodes.knowledgeBase.embeddingModelIsInvalid": "एम्बेडिंग मॉडल अमान्य है", "nodes.knowledgeBase.embeddingModelIsRequired": "एम्बेडिंग मॉडल आवश्यक है", + "nodes.knowledgeBase.embeddingModelNotConfigured": 
"एम्बेडिंग मॉडल कॉन्फ़िगर नहीं है", "nodes.knowledgeBase.indexMethodIsRequired": "सूची विधि आवश्यक है", + "nodes.knowledgeBase.notConfigured": "कॉन्फ़िगर नहीं है", "nodes.knowledgeBase.rerankingModelIsInvalid": "पुनः क्रमांकन मॉडल अमान्य है", "nodes.knowledgeBase.rerankingModelIsRequired": "पुनः क्रमांकन मॉडल की आवश्यकता है", "nodes.knowledgeBase.retrievalSettingIsRequired": "पुनप्राप्ति सेटिंग आवश्यक है", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "केवल Jinja2 का समर्थन करता है", "nodes.templateTransform.inputVars": "इनपुट वेरिएबल्स", "nodes.templateTransform.outputVars.output": "रूपांतरित सामग्री", + "nodes.tool.authorizationRequired": "प्राधिकरण आवश्यक", "nodes.tool.authorize": "अधिकृत करें", "nodes.tool.inputVars": "इनपुट वेरिएबल्स", "nodes.tool.insertPlaceholder1": "टाइप करें या दबाएँ", @@ -1062,10 +1071,12 @@ "panel.change": "बदलें", "panel.changeBlock": "नोड बदलें", "panel.checklist": "चेकलिस्ट", + "panel.checklistDescription": "प्रकाशित करने से पहले निम्नलिखित समस्याएँ हल करें", "panel.checklistResolved": "सभी समस्याएं हल हो गई हैं", "panel.checklistTip": "प्रकाशित करने से पहले सुनिश्चित करें कि सभी समस्याएं हल हो गई हैं", "panel.createdBy": "द्वारा बनाया गया ", "panel.goTo": "जाओ", + "panel.goToFix": "ठीक करने जाएँ", "panel.helpLink": "सहायता", "panel.maximize": "कैनवास का अधिकतम लाभ उठाएँ", "panel.minimize": "पूर्ण स्क्रीन से बाहर निकलें", diff --git a/web/i18n/id-ID/app-debug.json b/web/i18n/id-ID/app-debug.json index 21651349ba..feb0f0a296 100644 --- a/web/i18n/id-ID/app-debug.json +++ b/web/i18n/id-ID/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "Jalankan", "inputs.title": "Debug & Pratinjau", "inputs.userInputField": "Bidang Input Pengguna", + "manageModels": "Kelola model", "modelConfig.modeType.chat": "Mengobrol", "modelConfig.modeType.completion": "Lengkap", "modelConfig.model": "Pola", "modelConfig.setTone": "Atur nada respons", "modelConfig.title": "Model dan Parameter", + "noModelProviderConfigured": "Belum ada penyedia model yang dikonfigurasi", + "noModelProviderConfiguredTip": "Instal atau konfigurasi penyedia model untuk memulai.", + "noModelSelected": "Belum ada model yang dipilih", + "noModelSelectedTip": "Konfigurasi model di atas untuk melanjutkan.", "noResult": "Output akan ditampilkan di sini.", "notSetAPIKey.description": "Kunci penyedia LLM belum diatur, dan perlu diatur sebelum debugging.", "notSetAPIKey.settingBtn": "Buka pengaturan", diff --git a/web/i18n/id-ID/common.json b/web/i18n/id-ID/common.json index 90531af712..51cd429992 100644 --- a/web/i18n/id-ID/common.json +++ b/web/i18n/id-ID/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Sah", "modelProvider.buyQuota": "Beli Kuota", "modelProvider.callTimes": "Waktu panggilan", + "modelProvider.card.aiCreditsInUse": "Kredit AI sedang digunakan", + "modelProvider.card.aiCreditsOption": "Kredit AI", + "modelProvider.card.apiKeyOption": "Kunci API", + "modelProvider.card.apiKeyRequired": "Kunci API diperlukan", + "modelProvider.card.apiKeyUnavailableFallback": "Kunci API tidak tersedia, kini menggunakan kredit AI", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Periksa konfigurasi kunci API Anda untuk beralih kembali", "modelProvider.card.buyQuota": "Beli Kuota", "modelProvider.card.callTimes": "Waktu panggilan", + "modelProvider.card.creditsExhaustedDescription": "Silakan tingkatkan paket Anda atau konfigurasikan kunci API", + "modelProvider.card.creditsExhaustedFallback": "Kredit AI habis, kini menggunakan kunci API", + 
"modelProvider.card.creditsExhaustedFallbackDescription": "Tingkatkan paket Anda untuk melanjutkan prioritas kredit AI.", + "modelProvider.card.creditsExhaustedMessage": "Kredit AI telah habis", "modelProvider.card.modelAPI": "Model {{modelName}} menggunakan Kunci API.", "modelProvider.card.modelNotSupported": "Model {{modelName}} tidak terpasang.", "modelProvider.card.modelSupported": "Model {{modelName}} menggunakan kuota ini.", + "modelProvider.card.noApiKeysDescription": "Tambahkan kunci API untuk mulai menggunakan kredensial model Anda sendiri.", + "modelProvider.card.noApiKeysFallback": "Tidak ada kunci API, menggunakan kredit AI sebagai gantinya", + "modelProvider.card.noApiKeysTitle": "Belum ada kunci API yang dikonfigurasi", + "modelProvider.card.noAvailableUsage": "Tidak ada penggunaan yang tersedia", "modelProvider.card.onTrial": "Sedang Diadili", "modelProvider.card.paid": "Dibayar", "modelProvider.card.priorityUse": "Penggunaan prioritas", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Menghapus Kunci API", "modelProvider.card.tip": "Kredit pesan mendukung model dari {{modelNames}}. Prioritas akan diberikan pada kuota yang dibayarkan. Kuota gratis akan digunakan setelah kuota yang dibayarkan habis.", "modelProvider.card.tokens": "Token", + "modelProvider.card.unavailable": "Tidak tersedia", + "modelProvider.card.upgradePlan": "tingkatkan paket Anda", + "modelProvider.card.usageLabel": "Penggunaan", + "modelProvider.card.usagePriority": "Prioritas Penggunaan", + "modelProvider.card.usagePriorityTip": "Tentukan sumber daya mana yang digunakan terlebih dahulu saat menjalankan model.", "modelProvider.collapse": "Roboh", "modelProvider.config": "Konfigurasi", "modelProvider.configLoadBalancing": "Penyeimbangan Beban Konfigurasi", @@ -387,9 +406,11 @@ "modelProvider.model": "Pola", "modelProvider.modelAndParameters": "Model dan Parameter", "modelProvider.modelHasBeenDeprecated": "Model ini tidak digunakan lagi", + "modelProvider.modelSettings": "Pengaturan Model", "modelProvider.models": "Model", "modelProvider.modelsNum": "Model {{num}}", "modelProvider.noModelFound": "Tidak ditemukan model untuk {{model}}", + "modelProvider.noneConfigured": "Konfigurasikan model sistem default untuk menjalankan aplikasi", "modelProvider.notConfigured": "Model sistem belum sepenuhnya dikonfigurasi", "modelProvider.parameters": "PARAMETER", "modelProvider.parametersInvalidRemoved": "Beberapa parameter tidak valid dan telah dihapus", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Setel ulang pada {{date}}", "modelProvider.searchModel": "Model pencarian", "modelProvider.selectModel": "Pilih model Anda", + "modelProvider.selector.aiCredits": "Kredit AI", + "modelProvider.selector.apiKeyUnavailable": "Kunci API tidak tersedia", + "modelProvider.selector.apiKeyUnavailableTip": "Kunci API telah dihapus. Silakan konfigurasikan kunci API baru.", + "modelProvider.selector.configure": "Konfigurasikan", + "modelProvider.selector.configureRequired": "Konfigurasi diperlukan", + "modelProvider.selector.creditsExhausted": "Kredit habis", + "modelProvider.selector.creditsExhaustedTip": "Kredit AI Anda telah habis. 
Silakan tingkatkan paket Anda atau tambahkan kunci API.", + "modelProvider.selector.disabled": "Dinonaktifkan", + "modelProvider.selector.discoverMoreInMarketplace": "Temukan lebih banyak di Marketplace", "modelProvider.selector.emptySetting": "Silakan buka pengaturan untuk mengonfigurasi", "modelProvider.selector.emptyTip": "Tidak ada model yang tersedia", + "modelProvider.selector.fromMarketplace": "Dari Marketplace", + "modelProvider.selector.incompatible": "Tidak kompatibel", + "modelProvider.selector.incompatibleTip": "Model ini tidak tersedia dalam versi saat ini. Silakan pilih model lain yang tersedia.", + "modelProvider.selector.install": "Instal", + "modelProvider.selector.modelProviderSettings": "Pengaturan Penyedia Model", + "modelProvider.selector.noProviderConfigured": "Belum ada penyedia model yang dikonfigurasi", + "modelProvider.selector.noProviderConfiguredDesc": "Jelajahi Marketplace untuk menginstal, atau konfigurasikan penyedia di pengaturan.", + "modelProvider.selector.onlyCompatibleModelsShown": "Hanya model yang kompatibel yang ditampilkan", "modelProvider.selector.rerankTip": "Silakan atur model Rerank", "modelProvider.selector.tip": "Model ini telah dihapus. Silakan tambahkan model atau pilih model lain.", "modelProvider.setupModelFirst": "Silakan atur model Anda terlebih dahulu", diff --git a/web/i18n/id-ID/plugin.json b/web/i18n/id-ID/plugin.json index 486de762f8..6b9c8f51af 100644 --- a/web/i18n/id-ID/plugin.json +++ b/web/i18n/id-ID/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Hapus plugin", "action.deleteContentLeft": "Apakah Anda ingin menghapus", "action.deleteContentRight": "Plugin?", + "action.deleteSuccess": "Plugin berhasil dihapus", "action.pluginInfo": "Info plugin", "action.usedInApps": "Plugin ini digunakan di {{num}} aplikasi.", "allCategories": "Semua Kategori", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Pasang", "detailPanel.operation.remove": "Hapus", "detailPanel.operation.update": "Pemutakhiran", + "detailPanel.operation.updateTooltip": "Perbarui untuk mengakses model terbaru.", "detailPanel.operation.viewDetail": "Lihat Detail", "detailPanel.serviceOk": "Layanan OK", "detailPanel.strategyNum": "{{num}} {{strategy}} TERMASUK", @@ -231,12 +233,18 @@ "source.local": "File Paket Lokal", "source.marketplace": "Pasar", "task.clearAll": "Hapus semua", + "task.errorMsg.github": "Plugin ini tidak terinstal secara otomatis.\nSilakan instal dari GitHub.", + "task.errorMsg.marketplace": "Plugin ini tidak terinstal secara otomatis.\nSilakan instal dari Marketplace.", + "task.errorMsg.unknown": "Plugin ini tidak terinstal.\nSumber plugin tidak dapat diidentifikasi.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Gagal menginstal plugin {{errorLength}}, klik untuk melihat", + "task.installFromGithub": "Instal dari GitHub", + "task.installFromMarketplace": "Instal dari Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Gagal menginstal {{errorLength}} plugin", "task.installing": "Memasang plugin.", + "task.installingHint": "Menginstal... 
Ini mungkin memerlukan beberapa menit.", "task.installingWithError": "Memasang {{installingLength}} plugin, {{successLength}} berhasil, {{errorLength}} gagal", "task.installingWithSuccess": "Memasang plugin {{installingLength}}, {{successLength}} berhasil.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/id-ID/workflow.json b/web/i18n/id-ID/workflow.json index 4f8e01b7f1..76e80be7d7 100644 --- a/web/i18n/id-ID/workflow.json +++ b/web/i18n/id-ID/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "memperbarui alur kerja", "error.startNodeRequired": "Silakan tambahkan node awal terlebih dahulu sebelum {{operation}}", "errorMsg.authRequired": "Otorisasi diperlukan", + "errorMsg.configureModel": "Konfigurasikan model", "errorMsg.fieldRequired": "{{field}} wajib diisi", "errorMsg.fields.code": "Kode", "errorMsg.fields.model": "Model", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Variabel Penglihatan", "errorMsg.invalidJson": "{{field}} adalah JSON yang tidak valid", "errorMsg.invalidVariable": "Variabel tidak valid", + "errorMsg.modelPluginNotInstalled": "Variabel tidak valid. Konfigurasikan model untuk mengaktifkan variabel ini.", "errorMsg.noValidTool": "{{field}} tidak ada alat yang valid dipilih", "errorMsg.rerankModelRequired": "Model Rerank yang dikonfigurasi diperlukan", "errorMsg.startNodeRequired": "Silakan tambahkan node awal terlebih dahulu sebelum {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Ukuran Jendela", "nodes.common.outputVars": "Variabel Keluaran", "nodes.common.pluginNotInstalled": "Plugin tidak terpasang", + "nodes.common.pluginsNotInstalled": "{{count}} plugin tidak terpasang", "nodes.common.retry.maxRetries": "percobaan ulang maks", "nodes.common.retry.ms": "Ms", "nodes.common.retry.retries": "{{num}} Percobaan Ulang", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Potongan", "nodes.knowledgeBase.chunksInputTip": "Variabel input dari node basis pengetahuan adalah Chunks. 
Tipe variabel adalah objek dengan Skema JSON tertentu yang harus konsisten dengan struktur chunk yang dipilih.", "nodes.knowledgeBase.chunksVariableIsRequired": "Variabel Chunks diperlukan", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "Kunci API tidak tersedia", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Kredit habis", + "nodes.knowledgeBase.embeddingModelIncompatible": "Tidak kompatibel", "nodes.knowledgeBase.embeddingModelIsInvalid": "Model embedding tidak valid", "nodes.knowledgeBase.embeddingModelIsRequired": "Model embedding diperlukan", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Model embedding belum dikonfigurasi", "nodes.knowledgeBase.indexMethodIsRequired": "Metode indeks diperlukan", + "nodes.knowledgeBase.notConfigured": "Belum dikonfigurasi", "nodes.knowledgeBase.rerankingModelIsInvalid": "Model reranking tidak valid", "nodes.knowledgeBase.rerankingModelIsRequired": "Model reranking diperlukan", "nodes.knowledgeBase.retrievalSettingIsRequired": "Pengaturan pengambilan diperlukan", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Hanya mendukung Jinja2", "nodes.templateTransform.inputVars": "Variabel Masukan", "nodes.templateTransform.outputVars.output": "Konten yang diubah", + "nodes.tool.authorizationRequired": "Otorisasi diperlukan", "nodes.tool.authorize": "Otorisasi", "nodes.tool.inputVars": "Variabel Masukan", "nodes.tool.insertPlaceholder1": "Ketik atau tekan", @@ -1062,10 +1071,12 @@ "panel.change": "Ubah", "panel.changeBlock": "Ubah Node", "panel.checklist": "Checklist", + "panel.checklistDescription": "Selesaikan masalah berikut sebelum menerbitkan", "panel.checklistResolved": "Semua masalah terselesaikan", "panel.checklistTip": "Pastikan semua masalah diselesaikan sebelum dipublikasikan", "panel.createdBy": "Dibuat oleh", "panel.goTo": "Pergi ke", + "panel.goToFix": "Pergi ke perbaikan", "panel.helpLink": "Docs", "panel.maximize": "Maksimalkan Kanvas", "panel.minimize": "Keluar dari Layar Penuh", diff --git a/web/i18n/it-IT/app-debug.json b/web/i18n/it-IT/app-debug.json index 94e4c00894..b9aad1ff6b 100644 --- a/web/i18n/it-IT/app-debug.json +++ b/web/i18n/it-IT/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "ESEGUI", "inputs.title": "Debug e Anteprima", "inputs.userInputField": "Campo Input Utente", + "manageModels": "Gestisci modelli", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Completamento", "modelConfig.model": "Modello", "modelConfig.setTone": "Imposta tono delle risposte", "modelConfig.title": "Modello e Parametri", + "noModelProviderConfigured": "Nessun fornitore di modelli configurato", + "noModelProviderConfiguredTip": "Installa o configura un fornitore di modelli per iniziare.", + "noModelSelected": "Nessun modello selezionato", + "noModelSelectedTip": "Configura un modello sopra per continuare.", "noResult": "L'output verrà visualizzato qui.", "notSetAPIKey.description": "La chiave del provider LLM non è stata impostata e deve essere impostata prima del debug.", "notSetAPIKey.settingBtn": "Vai alle impostazioni", diff --git a/web/i18n/it-IT/common.json b/web/i18n/it-IT/common.json index d7b57d6d08..283c090ea8 100644 --- a/web/i18n/it-IT/common.json +++ b/web/i18n/it-IT/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Non autorizzato", "modelProvider.buyQuota": "Acquista Quota", "modelProvider.callTimes": "Numero di chiamate", + "modelProvider.card.aiCreditsInUse": "Crediti AI in uso", + "modelProvider.card.aiCreditsOption": "Crediti AI", + 
"modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API Key richiesta", + "modelProvider.card.apiKeyUnavailableFallback": "API Key non disponibile, utilizzo dei crediti AI in corso", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Controlla la configurazione della tua API Key per tornare indietro", "modelProvider.card.buyQuota": "Acquista Quota", "modelProvider.card.callTimes": "Numero di chiamate", + "modelProvider.card.creditsExhaustedDescription": "Aggiorna il tuo piano o configura una API Key", + "modelProvider.card.creditsExhaustedFallback": "Crediti AI esauriti, utilizzo della API Key in corso", + "modelProvider.card.creditsExhaustedFallbackDescription": "Aggiorna il tuo piano per ripristinare la priorità dei crediti AI.", + "modelProvider.card.creditsExhaustedMessage": "I crediti AI sono esauriti", "modelProvider.card.modelAPI": "I modelli {{modelName}} stanno utilizzando la chiave API.", "modelProvider.card.modelNotSupported": "I modelli {{modelName}} non sono installati.", "modelProvider.card.modelSupported": "I modelli {{modelName}} stanno utilizzando questa quota.", + "modelProvider.card.noApiKeysDescription": "Aggiungi una API Key per iniziare a usare le tue credenziali del modello.", + "modelProvider.card.noApiKeysFallback": "Nessuna API Key, utilizzo dei crediti AI in corso", + "modelProvider.card.noApiKeysTitle": "Nessuna API Key ancora configurata", + "modelProvider.card.noAvailableUsage": "Nessun utilizzo disponibile", "modelProvider.card.onTrial": "In Prova", "modelProvider.card.paid": "Pagato", "modelProvider.card.priorityUse": "Uso prioritario", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Rimuovi API Key", "modelProvider.card.tip": "I crediti di messaggi supportano modelli di {{modelNames}}. Verrà data priorità alla quota pagata. 
La quota gratuita sarà utilizzata dopo l'esaurimento della quota pagata.", "modelProvider.card.tokens": "Token", + "modelProvider.card.unavailable": "Non disponibile", + "modelProvider.card.upgradePlan": "aggiorna il tuo piano", + "modelProvider.card.usageLabel": "Utilizzo", + "modelProvider.card.usagePriority": "Priorità di utilizzo", + "modelProvider.card.usagePriorityTip": "Imposta quale risorsa utilizzare per prima durante l'esecuzione dei modelli.", "modelProvider.collapse": "Comprimi", "modelProvider.config": "Configura", "modelProvider.configLoadBalancing": "Configura Bilanciamento del Carico", @@ -387,9 +406,11 @@ "modelProvider.model": "Modello", "modelProvider.modelAndParameters": "Modello e Parametri", "modelProvider.modelHasBeenDeprecated": "Questo modello è stato deprecato", + "modelProvider.modelSettings": "Impostazioni modello", "modelProvider.models": "Modelli", "modelProvider.modelsNum": "{{num}} Modelli", "modelProvider.noModelFound": "Nessun modello trovato per {{model}}", + "modelProvider.noneConfigured": "Configura un modello di sistema predefinito per eseguire le applicazioni", "modelProvider.notConfigured": "Il modello di sistema non è ancora stato completamente configurato e alcune funzioni potrebbero non essere disponibili.", "modelProvider.parameters": "PARAMETRI", "modelProvider.parametersInvalidRemoved": "Alcuni parametri non sono validi e sono stati rimossi.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Ripristina il {{date}}", "modelProvider.searchModel": "Modello di ricerca", "modelProvider.selectModel": "Seleziona il tuo modello", + "modelProvider.selector.aiCredits": "Crediti AI", + "modelProvider.selector.apiKeyUnavailable": "API Key non disponibile", + "modelProvider.selector.apiKeyUnavailableTip": "La API Key è stata rimossa. Configura una nuova API Key.", + "modelProvider.selector.configure": "Configura", + "modelProvider.selector.configureRequired": "Configurazione richiesta", + "modelProvider.selector.creditsExhausted": "Crediti esauriti", + "modelProvider.selector.creditsExhaustedTip": "I tuoi crediti AI sono esauriti. Aggiorna il tuo piano o aggiungi una API Key.", + "modelProvider.selector.disabled": "Disabilitato", + "modelProvider.selector.discoverMoreInMarketplace": "Scopri di più nel Marketplace", "modelProvider.selector.emptySetting": "Per favore vai alle impostazioni per configurare", "modelProvider.selector.emptyTip": "Nessun modello disponibile", + "modelProvider.selector.fromMarketplace": "Dal Marketplace", + "modelProvider.selector.incompatible": "Incompatibile", + "modelProvider.selector.incompatibleTip": "Questo modello non è disponibile nella versione corrente. Seleziona un altro modello disponibile.", + "modelProvider.selector.install": "Installa", + "modelProvider.selector.modelProviderSettings": "Impostazioni fornitore modelli", + "modelProvider.selector.noProviderConfigured": "Nessun fornitore di modelli configurato", + "modelProvider.selector.noProviderConfiguredDesc": "Cerca nel Marketplace per installarne uno o configura i fornitori nelle impostazioni.", + "modelProvider.selector.onlyCompatibleModelsShown": "Vengono mostrati solo i modelli compatibili", "modelProvider.selector.rerankTip": "Per favore, configura il modello di Rerank", "modelProvider.selector.tip": "Questo modello è stato rimosso. 
Per favore aggiungi un modello o seleziona un altro modello.", "modelProvider.setupModelFirst": "Per favore, configura prima il tuo modello", diff --git a/web/i18n/it-IT/plugin.json b/web/i18n/it-IT/plugin.json index 06a2dfa4ab..296aa31d54 100644 --- a/web/i18n/it-IT/plugin.json +++ b/web/i18n/it-IT/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Rimuovi plugin", "action.deleteContentLeft": "Vorresti rimuovere", "action.deleteContentRight": "plugin?", + "action.deleteSuccess": "Plugin rimosso con successo", "action.pluginInfo": "Informazioni sul plugin", "action.usedInApps": "Questo plugin viene utilizzato nelle app {{num}}.", "allCategories": "Tutte le categorie", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Installare", "detailPanel.operation.remove": "Togliere", "detailPanel.operation.update": "Aggiornare", + "detailPanel.operation.updateTooltip": "Aggiorna per accedere ai modelli più recenti.", "detailPanel.operation.viewDetail": "vedi dettagli", "detailPanel.serviceOk": "Servizio OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUSO", @@ -231,12 +233,18 @@ "source.local": "File del pacchetto locale", "source.marketplace": "Mercato", "task.clearAll": "Cancella tutto", + "task.errorMsg.github": "Impossibile installare automaticamente questo plugin.\nInstallalo da GitHub.", + "task.errorMsg.marketplace": "Impossibile installare automaticamente questo plugin.\nInstallalo dal Marketplace.", + "task.errorMsg.unknown": "Impossibile installare questo plugin.\nNon è stato possibile identificare la fonte del plugin.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Impossibile installare i plugin {{errorLength}}, clicca per visualizzare", + "task.installFromGithub": "Installa da GitHub", + "task.installFromMarketplace": "Installa dal Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Impossibile installare i plugin di {{errorLength}}", "task.installing": "Installazione dei plugin.", + "task.installingHint": "Installazione in corso... Potrebbe richiedere qualche minuto.", "task.installingWithError": "Installazione dei plugin {{installingLength}}, {{successLength}} successo, {{errorLength}} fallito", "task.installingWithSuccess": "Installazione dei plugin {{installingLength}}, {{successLength}} successo.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/it-IT/workflow.json b/web/i18n/it-IT/workflow.json index 0184b85009..519d7d7e2a 100644 --- a/web/i18n/it-IT/workflow.json +++ b/web/i18n/it-IT/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "aggiornamento del flusso di lavoro", "error.startNodeRequired": "Per favore aggiungi prima un nodo iniziale prima di {{operation}}", "errorMsg.authRequired": "È richiesta l'autorizzazione", + "errorMsg.configureModel": "Configura un modello", "errorMsg.fieldRequired": "{{field}} è richiesto", "errorMsg.fields.code": "Codice", "errorMsg.fields.model": "Modello", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Visione variabile", "errorMsg.invalidJson": "{{field}} è un JSON non valido", "errorMsg.invalidVariable": "Variabile non valida", + "errorMsg.modelPluginNotInstalled": "Variabile non valida. 
Configura un modello per abilitare questa variabile.", "errorMsg.noValidTool": "{{field}} nessuno strumento valido selezionato", "errorMsg.rerankModelRequired": "Prima di attivare il modello di reranking, conferma che il modello è stato configurato correttamente nelle impostazioni.", "errorMsg.startNodeRequired": "Per favore aggiungi prima un nodo iniziale prima di {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Dimensione Finestra", "nodes.common.outputVars": "Variabili di Output", "nodes.common.pluginNotInstalled": "Il plugin non è installato", + "nodes.common.pluginsNotInstalled": "{{count}} plugin non installati", "nodes.common.retry.maxRetries": "Numero massimo di tentativi", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Tentativi", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Pezzetti", "nodes.knowledgeBase.chunksInputTip": "La variabile di input del nodo della base di conoscenza è Chunks. Il tipo di variabile è un oggetto con uno specifico schema JSON che deve essere coerente con la struttura del chunk selezionato.", "nodes.knowledgeBase.chunksVariableIsRequired": "La variabile Chunks è richiesta", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key non disponibile", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Crediti esauriti", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatibile", "nodes.knowledgeBase.embeddingModelIsInvalid": "Il modello di embedding non è valido", "nodes.knowledgeBase.embeddingModelIsRequired": "È necessario un modello di embedding", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modello di embedding non configurato", "nodes.knowledgeBase.indexMethodIsRequired": "È necessario il metodo dell'indice", + "nodes.knowledgeBase.notConfigured": "Non configurato", "nodes.knowledgeBase.rerankingModelIsInvalid": "Il modello di riorganizzazione è non valido", "nodes.knowledgeBase.rerankingModelIsRequired": "È richiesto un modello di riordinamento", "nodes.knowledgeBase.retrievalSettingIsRequired": "È richiesta l'impostazione di recupero", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Supporta solo Jinja2", "nodes.templateTransform.inputVars": "Variabili di Input", "nodes.templateTransform.outputVars.output": "Contenuto trasformato", + "nodes.tool.authorizationRequired": "Autorizzazione richiesta", "nodes.tool.authorize": "Autorizza", "nodes.tool.inputVars": "Variabili di Input", "nodes.tool.insertPlaceholder1": "Digita o premi", @@ -1062,10 +1071,12 @@ "panel.change": "Cambia", "panel.changeBlock": "Cambia Nodo", "panel.checklist": "Checklist", + "panel.checklistDescription": "Risolvi i seguenti problemi prima della pubblicazione", "panel.checklistResolved": "Tutti i problemi sono risolti", "panel.checklistTip": "Assicurati che tutti i problemi siano risolti prima di pubblicare", "panel.createdBy": "Creato da ", "panel.goTo": "Vai a", + "panel.goToFix": "Vai a correggere", "panel.helpLink": "Aiuto", "panel.maximize": "Massimizza Canvas", "panel.minimize": "Esci dalla modalità schermo intero", diff --git a/web/i18n/ja-JP/common.json b/web/i18n/ja-JP/common.json index 98fe556ecf..a65d8e933c 100644 --- a/web/i18n/ja-JP/common.json +++ b/web/i18n/ja-JP/common.json @@ -358,6 +358,7 @@ "modelProvider.card.noApiKeysDescription": "独自のモデル認証情報を使用するには、API キーを追加してください。", "modelProvider.card.noApiKeysFallback": "API キーが未設定のため、AI クレジットを使用しています", "modelProvider.card.noApiKeysTitle": "API キーはまだ設定されていません", + "modelProvider.card.noAvailableUsage": "利用可能な使用量がありません", 
"modelProvider.card.onTrial": "トライアル中", "modelProvider.card.paid": "有料", "modelProvider.card.priorityUse": "優先利用", @@ -366,6 +367,9 @@ "modelProvider.card.removeKey": "API キーを削除", "modelProvider.card.tip": "メッセージ枠は{{modelNames}}のモデルを使用することをサポートしています。無料枠は有料枠が使い果たされた後に消費されます。", "modelProvider.card.tokens": "トークン", + "modelProvider.card.unavailable": "利用不可", + "modelProvider.card.upgradePlan": "プランをアップグレード", + "modelProvider.card.usageLabel": "使用量", "modelProvider.card.usagePriority": "使用優先順位", "modelProvider.card.usagePriorityTip": "モデル実行時に優先して使用するリソースを設定します。", "modelProvider.collapse": "折り畳み", @@ -402,9 +406,11 @@ "modelProvider.model": "モデル", "modelProvider.modelAndParameters": "モデルとパラメータ", "modelProvider.modelHasBeenDeprecated": "このモデルは廃止予定です", + "modelProvider.modelSettings": "モデル設定", "modelProvider.models": "モデル", "modelProvider.modelsNum": "{{num}}のモデル", "modelProvider.noModelFound": "{{model}}に対するモデルが見つかりません", + "modelProvider.noneConfigured": "アプリケーションを実行するためにデフォルトのシステムモデルを設定してください", "modelProvider.notConfigured": "システムモデルがまだ完全に設定されておらず、一部の機能が利用できない場合があります。", "modelProvider.parameters": "パラメータ", "modelProvider.parametersInvalidRemoved": "いくつかのパラメータが無効であり、削除されました。", @@ -421,14 +427,21 @@ "modelProvider.selector.aiCredits": "AI クレジット", "modelProvider.selector.apiKeyUnavailable": "API キーが利用できません", "modelProvider.selector.apiKeyUnavailableTip": "API キーは削除されました。新しい API キーを設定してください。", + "modelProvider.selector.configure": "設定", "modelProvider.selector.configureRequired": "設定が必要です", "modelProvider.selector.creditsExhausted": "クレジットを使い切りました", "modelProvider.selector.creditsExhaustedTip": "AI クレジットを使い切りました。プランをアップグレードするか、API キーを追加してください。", + "modelProvider.selector.disabled": "無効", + "modelProvider.selector.discoverMoreInMarketplace": "マーケットプレイスでもっと探す", "modelProvider.selector.emptySetting": "設定に移動して構成してください", "modelProvider.selector.emptyTip": "利用可能なモデルはありません", "modelProvider.selector.fromMarketplace": "マーケットプレイスから", "modelProvider.selector.incompatible": "非互換", "modelProvider.selector.incompatibleTip": "このモデルは現在のバージョンでは利用できません。別の利用可能なモデルを選択してください。", + "modelProvider.selector.install": "インストール", + "modelProvider.selector.modelProviderSettings": "モデルプロバイダー設定", + "modelProvider.selector.noProviderConfigured": "モデルプロバイダーが設定されていません", + "modelProvider.selector.noProviderConfiguredDesc": "マーケットプレイスでインストールするか、設定でプロバイダーを設定してください。", "modelProvider.selector.onlyCompatibleModelsShown": "互換性のあるモデルのみが表示されます", "modelProvider.selector.rerankTip": "Rerank モデルを設定してください", "modelProvider.selector.tip": "このモデルは削除されました。別のモデルを追加するか、別のモデルを選択してください。", diff --git a/web/i18n/ja-JP/plugin.json b/web/i18n/ja-JP/plugin.json index c51a0e4117..38645124fb 100644 --- a/web/i18n/ja-JP/plugin.json +++ b/web/i18n/ja-JP/plugin.json @@ -3,6 +3,7 @@ "action.delete": "プラグインを削除する", "action.deleteContentLeft": "削除しますか", "action.deleteContentRight": "プラグイン?", + "action.deleteSuccess": "プラグインが正常に削除されました", "action.pluginInfo": "プラグイン情報", "action.usedInApps": "このプラグインは{{num}}のアプリで使用されています。", "allCategories": "すべてのカテゴリ", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "インストール", "detailPanel.operation.remove": "削除", "detailPanel.operation.update": "更新", + "detailPanel.operation.updateTooltip": "最新のモデルにアクセスするために更新してください。", "detailPanel.operation.viewDetail": "詳細を見る", "detailPanel.serviceOk": "サービスは正常です", "detailPanel.strategyNum": "{{num}} {{strategy}} が含まれています", @@ -231,12 +233,18 @@ "source.local": "ローカルパッケージファイル", "source.marketplace": "マーケットプレイス", "task.clearAll": "すべてクリア", + "task.errorMsg.github": 
"このプラグインは自動インストールできませんでした。\nGitHub からインストールしてください。", + "task.errorMsg.marketplace": "このプラグインは自動インストールできませんでした。\nマーケットプレイスからインストールしてください。", + "task.errorMsg.unknown": "このプラグインをインストールできませんでした。\nプラグインのソースを特定できませんでした。", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} プラグインのインストールに失敗しました。表示するにはクリックしてください。", + "task.installFromGithub": "GitHub からインストール", + "task.installFromMarketplace": "マーケットプレイスからインストール", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} プラグインのインストールに失敗しました", "task.installing": "プラグインをインストール中。", + "task.installingHint": "インストール中...数分かかる場合があります。", "task.installingWithError": "{{installingLength}}個のプラグインをインストール中、{{successLength}}件成功、{{errorLength}}件失敗", "task.installingWithSuccess": "{{installingLength}}個のプラグインをインストール中、{{successLength}}個成功しました。", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/ja-JP/workflow.json b/web/i18n/ja-JP/workflow.json index a2fac68bad..b9b60d6e73 100644 --- a/web/i18n/ja-JP/workflow.json +++ b/web/i18n/ja-JP/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "ワークフロー更新", "error.startNodeRequired": "{{operation}}前に開始ノードを追加してください", "errorMsg.authRequired": "認証が必要です", + "errorMsg.configureModel": "モデルを設定してください", "errorMsg.fieldRequired": "{{field}} は必須です", "errorMsg.fields.code": "コード", "errorMsg.fields.model": "モデル", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "ビジョン変数", "errorMsg.invalidJson": "{{field}} は無効な JSON です", "errorMsg.invalidVariable": "無効な変数です", + "errorMsg.modelPluginNotInstalled": "無効な変数です。この変数を有効にするにはモデルを設定してください。", "errorMsg.noValidTool": "{{field}} に利用可能なツールがありません", "errorMsg.rerankModelRequired": "Rerank モデルが設定されていません", "errorMsg.startNodeRequired": "{{operation}}前に開始ノードを追加してください", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "メモリウィンドウサイズ", "nodes.common.outputVars": "出力変数", "nodes.common.pluginNotInstalled": "プラグインがインストールされていません", + "nodes.common.pluginsNotInstalled": "{{count}} 個のプラグインがインストールされていません", "nodes.common.retry.maxRetries": "最大試行回数", "nodes.common.retry.ms": "ミリ秒", "nodes.common.retry.retries": "再試行回数:{{num}}", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "チャンク", "nodes.knowledgeBase.chunksInputTip": "知識ベースノードの入力変数はチャンクです。変数のタイプは、選択されたチャンク構造と一貫性のある特定のJSONスキーマを持つオブジェクトです。", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks変数は必須です", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API キーが利用できません", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "クレジットを使い切りました", + "nodes.knowledgeBase.embeddingModelIncompatible": "非互換", "nodes.knowledgeBase.embeddingModelIsInvalid": "埋め込みモデルが無効です", "nodes.knowledgeBase.embeddingModelIsRequired": "埋め込みモデルが必要です", + "nodes.knowledgeBase.embeddingModelNotConfigured": "埋め込みモデルが設定されていません", "nodes.knowledgeBase.indexMethodIsRequired": "インデックスメソッドが必要です", + "nodes.knowledgeBase.notConfigured": "未設定", "nodes.knowledgeBase.rerankingModelIsInvalid": "リランキングモデルは無効です", "nodes.knowledgeBase.rerankingModelIsRequired": "再ランキングモデルが必要です", "nodes.knowledgeBase.retrievalSettingIsRequired": "リトリーバル設定が必要です", @@ -1063,10 +1071,12 @@ "panel.change": "変更", "panel.changeBlock": "ノード変更", "panel.checklist": "チェックリスト", + "panel.checklistDescription": "公開前に以下の問題を解決してください", "panel.checklistResolved": "全てのチェックが完了しました", "panel.checklistTip": "公開前に全ての項目を確認してください", "panel.createdBy": "作成者", "panel.goTo": "移動", + "panel.goToFix": "修正する", "panel.helpLink": "ドキュメントを見る", "panel.maximize": "キャンバスを最大化する", 
"panel.minimize": "全画面を終了する", diff --git a/web/i18n/ko-KR/app-debug.json b/web/i18n/ko-KR/app-debug.json index 3cb295a3f7..1f9c1e2420 100644 --- a/web/i18n/ko-KR/app-debug.json +++ b/web/i18n/ko-KR/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "실행", "inputs.title": "디버그 및 미리보기", "inputs.userInputField": "사용자 입력 필드", + "manageModels": "모델 관리", "modelConfig.modeType.chat": "채팅", "modelConfig.modeType.completion": "완료", "modelConfig.model": "모델", "modelConfig.setTone": "응답 톤 설정", "modelConfig.title": "모델 및 매개변수", + "noModelProviderConfigured": "모델 공급자가 구성되지 않았습니다", + "noModelProviderConfiguredTip": "모델 공급자를 설치하거나 구성하여 시작하세요.", + "noModelSelected": "모델이 선택되지 않았습니다", + "noModelSelectedTip": "계속하려면 위에서 모델을 구성하세요.", "noResult": "출력이 여기에 표시됩니다.", "notSetAPIKey.description": "LLM 제공자 키가 설정되지 않았습니다. 디버깅하기 전에 설정해야 합니다.", "notSetAPIKey.settingBtn": "설정으로 이동", diff --git a/web/i18n/ko-KR/common.json b/web/i18n/ko-KR/common.json index 8fa351c517..e50a7c2428 100644 --- a/web/i18n/ko-KR/common.json +++ b/web/i18n/ko-KR/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "무단", "modelProvider.buyQuota": "할당량 구매", "modelProvider.callTimes": "호출 횟수", + "modelProvider.card.aiCreditsInUse": "AI 크레딧 사용 중", + "modelProvider.card.aiCreditsOption": "AI 크레딧", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API 키 설정 필요", + "modelProvider.card.apiKeyUnavailableFallback": "API Key를 사용할 수 없어 AI 크레딧을 사용 중입니다", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "API Key 설정을 확인하여 다시 전환하세요", "modelProvider.card.buyQuota": "Buy Quota", "modelProvider.card.callTimes": "호출 횟수", + "modelProvider.card.creditsExhaustedDescription": "플랜을 업그레이드하거나 API 키를 설정하세요", + "modelProvider.card.creditsExhaustedFallback": "AI 크레딧이 소진되어 API Key를 사용 중입니다", + "modelProvider.card.creditsExhaustedFallbackDescription": "AI 크레딧 우선 사용을 재개하려면 플랜을 업그레이드하세요.", + "modelProvider.card.creditsExhaustedMessage": "AI 크레딧이 소진되었습니다", "modelProvider.card.modelAPI": "{{modelName}} 모델이 API 키를 사용하고 있습니다.", "modelProvider.card.modelNotSupported": "{{modelName}} 모델이 설치되지 않았습니다.", "modelProvider.card.modelSupported": "{{modelName}} 모델이 이 할당량을 사용하고 있습니다.", + "modelProvider.card.noApiKeysDescription": "자체 모델 자격 증명을 사용하려면 API 키를 추가하세요.", + "modelProvider.card.noApiKeysFallback": "API Key가 없어 AI 크레딧을 사용 중입니다", + "modelProvider.card.noApiKeysTitle": "아직 API 키가 구성되지 않았습니다", + "modelProvider.card.noAvailableUsage": "사용 가능한 용량 없음", "modelProvider.card.onTrial": "트라이얼 중", "modelProvider.card.paid": "유료", "modelProvider.card.priorityUse": "우선 사용", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "API 키 제거", "modelProvider.card.tip": "메시지 크레딧은 {{modelNames}}의 모델을 지원합니다. 유료 할당량에 우선순위가 부여됩니다. 
무료 할당량은 유료 할당량이 소진된 후 사용됩니다.", "modelProvider.card.tokens": "토큰", + "modelProvider.card.unavailable": "사용 불가", + "modelProvider.card.upgradePlan": "플랜 업그레이드", + "modelProvider.card.usageLabel": "사용량", + "modelProvider.card.usagePriority": "사용 우선순위", + "modelProvider.card.usagePriorityTip": "모델 실행 시 우선 사용할 리소스를 설정합니다.", "modelProvider.collapse": "축소", "modelProvider.config": "설정", "modelProvider.configLoadBalancing": "Config 로드 밸런싱", @@ -387,9 +406,11 @@ "modelProvider.model": "모델", "modelProvider.modelAndParameters": "모델 및 매개변수", "modelProvider.modelHasBeenDeprecated": "이 모델은 더 이상 사용되지 않습니다", + "modelProvider.modelSettings": "모델 설정", "modelProvider.models": "모델", "modelProvider.modelsNum": "{{num}}개의 모델", "modelProvider.noModelFound": "{{model}}에 대한 모델을 찾을 수 없습니다", + "modelProvider.noneConfigured": "애플리케이션을 실행하려면 기본 시스템 모델을 구성하세요", "modelProvider.notConfigured": "시스템 모델이 아직 완전히 설정되지 않아 일부 기능을 사용할 수 없습니다.", "modelProvider.parameters": "매개변수", "modelProvider.parametersInvalidRemoved": "일부 매개변수가 유효하지 않아 제거되었습니다.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "{{date}}에 재설정", "modelProvider.searchModel": "검색 모델", "modelProvider.selectModel": "모델 선택", + "modelProvider.selector.aiCredits": "AI 크레딧", + "modelProvider.selector.apiKeyUnavailable": "API Key 사용 불가", + "modelProvider.selector.apiKeyUnavailableTip": "API Key가 삭제되었습니다. 새 API Key를 설정하세요.", + "modelProvider.selector.configure": "구성", + "modelProvider.selector.configureRequired": "구성 필요", + "modelProvider.selector.creditsExhausted": "크레딧 소진", + "modelProvider.selector.creditsExhaustedTip": "AI 크레딧이 소진되었습니다. 플랜을 업그레이드하거나 API 키를 추가하세요.", + "modelProvider.selector.disabled": "비활성화됨", + "modelProvider.selector.discoverMoreInMarketplace": "마켓플레이스에서 더 찾아보기", "modelProvider.selector.emptySetting": "설정으로 이동하여 구성하세요", "modelProvider.selector.emptyTip": "사용 가능한 모델이 없습니다", + "modelProvider.selector.fromMarketplace": "마켓플레이스에서", + "modelProvider.selector.incompatible": "호환되지 않음", + "modelProvider.selector.incompatibleTip": "이 모델은 현재 버전에서 사용할 수 없습니다. 다른 사용 가능한 모델을 선택하세요.", + "modelProvider.selector.install": "설치", + "modelProvider.selector.modelProviderSettings": "모델 공급자 설정", + "modelProvider.selector.noProviderConfigured": "모델 공급자가 구성되지 않았습니다", + "modelProvider.selector.noProviderConfiguredDesc": "마켓플레이스에서 설치하거나 설정에서 공급자를 구성하세요.", + "modelProvider.selector.onlyCompatibleModelsShown": "호환되는 모델만 표시됩니다", "modelProvider.selector.rerankTip": "재랭크 모델을 설정하세요", "modelProvider.selector.tip": "이 모델은 삭제되었습니다. 
다른 모델을 추가하거나 다른 모델을 선택하세요.", "modelProvider.setupModelFirst": "먼저 모델을 설정하세요", diff --git a/web/i18n/ko-KR/plugin.json b/web/i18n/ko-KR/plugin.json index 2f20a584c5..93b28beeb4 100644 --- a/web/i18n/ko-KR/plugin.json +++ b/web/i18n/ko-KR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "플러그인 제거", "action.deleteContentLeft": "제거하시겠습니까?", "action.deleteContentRight": "플러그인?", + "action.deleteSuccess": "플러그인이 성공적으로 제거되었습니다", "action.pluginInfo": "플러그인 정보", "action.usedInApps": "이 플러그인은 {{num}}개의 앱에서 사용되고 있습니다.", "allCategories": "모든 카테고리", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "설치", "detailPanel.operation.remove": "제거", "detailPanel.operation.update": "업데이트", + "detailPanel.operation.updateTooltip": "최신 모델에 액세스하려면 업데이트하세요.", "detailPanel.operation.viewDetail": "자세히보기", "detailPanel.serviceOk": "서비스 정상", "detailPanel.strategyNum": "{{num}} {{strategy}} 포함", @@ -231,12 +233,18 @@ "source.local": "로컬 패키지 파일", "source.marketplace": "마켓", "task.clearAll": "모두 지우기", + "task.errorMsg.github": "이 플러그인을 자동으로 설치할 수 없었습니다.\nGitHub에서 설치하세요.", + "task.errorMsg.marketplace": "이 플러그인을 자동으로 설치할 수 없었습니다.\n마켓플레이스에서 설치하세요.", + "task.errorMsg.unknown": "이 플러그인을 설치할 수 없었습니다.\n플러그인 소스를 확인할 수 없습니다.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} 플러그인 설치 실패, 보려면 클릭하십시오.", + "task.installFromGithub": "GitHub에서 설치", + "task.installFromMarketplace": "마켓플레이스에서 설치", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} 플러그인 설치 실패", "task.installing": "플러그인 설치 중.", + "task.installingHint": "설치 중... 몇 분 정도 걸릴 수 있습니다.", "task.installingWithError": "{{installingLength}} 플러그인 설치, {{successLength}} 성공, {{errorLength}} 실패", "task.installingWithSuccess": "{{installingLength}} 플러그인 설치, {{successLength}} 성공.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/ko-KR/workflow.json b/web/i18n/ko-KR/workflow.json index b83afea149..24e8e634d3 100644 --- a/web/i18n/ko-KR/workflow.json +++ b/web/i18n/ko-KR/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "워크플로 업데이트", "error.startNodeRequired": "{{operation}} 전에 먼저 시작 노드를 추가해 주세요", "errorMsg.authRequired": "인증이 필요합니다", + "errorMsg.configureModel": "모델을 구성하세요", "errorMsg.fieldRequired": "{{field}}가 필요합니다", "errorMsg.fields.code": "코드", "errorMsg.fields.model": "모델", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "비전 변수", "errorMsg.invalidJson": "{{field}}는 잘못된 JSON 입니다", "errorMsg.invalidVariable": "잘못된 변수", + "errorMsg.modelPluginNotInstalled": "잘못된 변수입니다. 이 변수를 사용하려면 모델을 구성하세요.", "errorMsg.noValidTool": "{{field}} 유효한 도구가 선택되지 않았습니다.", "errorMsg.rerankModelRequired": "Rerank Model 을 켜기 전에 설정에서 모델이 성공적으로 구성되었는지 확인하십시오.", "errorMsg.startNodeRequired": "{{operation}} 전에 먼저 시작 노드를 추가해 주세요", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "창 크기", "nodes.common.outputVars": "출력 변수", "nodes.common.pluginNotInstalled": "플러그인이 설치되지 않았습니다", + "nodes.common.pluginsNotInstalled": "{{count}}개의 플러그인이 설치되지 않았습니다", "nodes.common.retry.maxRetries": "최대 재시도 횟수", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} 재시도", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "청크", "nodes.knowledgeBase.chunksInputTip": "지식 기반 노드의 입력 변수는 Chunks입니다. 
변수 유형은 선택된 청크 구조와 일치해야 하는 특정 JSON 스키마를 가진 객체입니다.", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks 변수는 필수입니다", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key 사용 불가", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "크레딧 소진", + "nodes.knowledgeBase.embeddingModelIncompatible": "호환되지 않음", "nodes.knowledgeBase.embeddingModelIsInvalid": "임베딩 모델이 유효하지 않습니다", "nodes.knowledgeBase.embeddingModelIsRequired": "임베딩 모델이 필요합니다", + "nodes.knowledgeBase.embeddingModelNotConfigured": "임베딩 모델이 구성되지 않았습니다", "nodes.knowledgeBase.indexMethodIsRequired": "인덱스 메서드가 필요합니다.", + "nodes.knowledgeBase.notConfigured": "구성되지 않음", "nodes.knowledgeBase.rerankingModelIsInvalid": "재정렬 모델이 유효하지 않습니다", "nodes.knowledgeBase.rerankingModelIsRequired": "재순위 모델이 필요합니다", "nodes.knowledgeBase.retrievalSettingIsRequired": "검색 설정이 필요합니다.", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Jinja2 만 지원합니다", "nodes.templateTransform.inputVars": "입력 변수", "nodes.templateTransform.outputVars.output": "변환된 내용", + "nodes.tool.authorizationRequired": "인증 필요", "nodes.tool.authorize": "권한 부여", "nodes.tool.inputVars": "입력 변수", "nodes.tool.insertPlaceholder1": "타이프하거나 누르세요", @@ -1062,10 +1071,12 @@ "panel.change": "변경", "panel.changeBlock": "노드 변경", "panel.checklist": "체크리스트", + "panel.checklistDescription": "게시 전에 다음 문제를 해결하세요", "panel.checklistResolved": "모든 문제가 해결되었습니다", "panel.checklistTip": "게시하기 전에 모든 문제가 해결되었는지 확인하세요", "panel.createdBy": "작성자 ", "panel.goTo": "로 이동", + "panel.goToFix": "수정하러 가기", "panel.helpLink": "도움말 센터", "panel.maximize": "캔버스 전체 화면", "panel.minimize": "전체 화면 종료", diff --git a/web/i18n/nl-NL/app-debug.json b/web/i18n/nl-NL/app-debug.json index b667cfb052..7552cb7cc9 100644 --- a/web/i18n/nl-NL/app-debug.json +++ b/web/i18n/nl-NL/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "RUN", "inputs.title": "Debug & Preview", "inputs.userInputField": "User Input Field", + "manageModels": "Modellen beheren", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Complete", "modelConfig.model": "Model", "modelConfig.setTone": "Set tone of responses", "modelConfig.title": "Model and Parameters", + "noModelProviderConfigured": "Geen modelprovider geconfigureerd", + "noModelProviderConfiguredTip": "Installeer of configureer een modelprovider om te beginnen.", + "noModelSelected": "Geen model geselecteerd", + "noModelSelectedTip": "Configureer hierboven een model om door te gaan.", "noResult": "Output will be displayed here.", "notSetAPIKey.description": "The LLM provider key has not been set, and it needs to be set before debugging.", "notSetAPIKey.settingBtn": "Go to settings", diff --git a/web/i18n/nl-NL/common.json b/web/i18n/nl-NL/common.json index 013f06b035..fb1b332a0c 100644 --- a/web/i18n/nl-NL/common.json +++ b/web/i18n/nl-NL/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Unauthorized", "modelProvider.buyQuota": "Buy Quota", "modelProvider.callTimes": "Call times", + "modelProvider.card.aiCreditsInUse": "AI-tegoeden in gebruik", + "modelProvider.card.aiCreditsOption": "AI-tegoeden", + "modelProvider.card.apiKeyOption": "API-sleutel", + "modelProvider.card.apiKeyRequired": "API-sleutel vereist", + "modelProvider.card.apiKeyUnavailableFallback": "API-sleutel niet beschikbaar, nu worden AI-tegoeden gebruikt", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Controleer uw API-sleutelconfiguratie om terug te schakelen", "modelProvider.card.buyQuota": "Buy Quota", "modelProvider.card.callTimes": "Call times", + 
"modelProvider.card.creditsExhaustedDescription": "Verhaag uw abonnement of configureer een API-sleutel", + "modelProvider.card.creditsExhaustedFallback": "AI-tegoeden uitgeput, nu wordt API-sleutel gebruikt", + "modelProvider.card.creditsExhaustedFallbackDescription": "Verhoog uw abonnement om AI-tegoed prioriteit te hervatten.", + "modelProvider.card.creditsExhaustedMessage": "AI-tegoeden zijn uitgeput", "modelProvider.card.modelAPI": "{{modelName}} models are using the API Key.", - "modelProvider.card.modelNotSupported": "{{modelName}} models are not installed.", - "modelProvider.card.modelSupported": "{{modelName}} models are using this quota.", + "modelProvider.card.modelNotSupported": "{{modelName}} niet geïnstalleerd", + "modelProvider.card.modelSupported": "{{modelName}} modellen gebruiken deze tegoeden.", + "modelProvider.card.noApiKeysDescription": "Voeg een API-sleutel toe om uw eigen modelgegevens te gebruiken.", + "modelProvider.card.noApiKeysFallback": "Geen API-sleutels, AI-tegoeden worden gebruikt", + "modelProvider.card.noApiKeysTitle": "Nog geen API-sleutels geconfigureerd", + "modelProvider.card.noAvailableUsage": "Geen beschikbaar gebruik", "modelProvider.card.onTrial": "On Trial", "modelProvider.card.paid": "Paid", "modelProvider.card.priorityUse": "Priority use", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Remove API Key", "modelProvider.card.tip": "Message Credits supports models from {{modelNames}}. Priority will be given to the paid quota. The free quota will be used after the paid quota is exhausted.", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "Niet beschikbaar", + "modelProvider.card.upgradePlan": "verhaag uw abonnement", + "modelProvider.card.usageLabel": "Gebruik", + "modelProvider.card.usagePriority": "Gebruiksprioriteit", + "modelProvider.card.usagePriorityTip": "Stel in welke resource als eerste wordt gebruikt bij het uitvoeren van modellen.", "modelProvider.collapse": "Collapse", "modelProvider.config": "Config", "modelProvider.configLoadBalancing": "Config Load Balancing", @@ -387,9 +406,11 @@ "modelProvider.model": "Model", "modelProvider.modelAndParameters": "Model and Parameters", "modelProvider.modelHasBeenDeprecated": "This model has been deprecated", + "modelProvider.modelSettings": "Modelinstellingen", "modelProvider.models": "Models", "modelProvider.modelsNum": "{{num}} Models", "modelProvider.noModelFound": "No model found for {{model}}", + "modelProvider.noneConfigured": "Configureer een standaard systeemmodel om applicaties uit te voeren", "modelProvider.notConfigured": "The system model has not yet been fully configured", "modelProvider.parameters": "PARAMETERS", "modelProvider.parametersInvalidRemoved": "Some parameters are invalid and have been removed", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Reset on {{date}}", "modelProvider.searchModel": "Search model", "modelProvider.selectModel": "Select your model", + "modelProvider.selector.aiCredits": "AI-tegoeden", + "modelProvider.selector.apiKeyUnavailable": "API-sleutel niet beschikbaar", + "modelProvider.selector.apiKeyUnavailableTip": "De API-sleutel is verwijderd. Configureer een nieuwe API-sleutel.", + "modelProvider.selector.configure": "Configureren", + "modelProvider.selector.configureRequired": "Configuratie vereist", + "modelProvider.selector.creditsExhausted": "Tegoeden uitgeput", + "modelProvider.selector.creditsExhaustedTip": "Uw AI-tegoeden zijn uitgeput. 
Verhoog uw abonnement of voeg een API-sleutel toe.", + "modelProvider.selector.disabled": "Uitgeschakeld", + "modelProvider.selector.discoverMoreInMarketplace": "Ontdek meer in de Marketplace", "modelProvider.selector.emptySetting": "Please go to settings to configure", "modelProvider.selector.emptyTip": "No available models", + "modelProvider.selector.fromMarketplace": "Vanuit de Marketplace", + "modelProvider.selector.incompatible": "Incompatibel", + "modelProvider.selector.incompatibleTip": "Dit model is niet beschikbaar in de huidige versie. Selecteer een ander beschikbaar model.", + "modelProvider.selector.install": "Installeren", + "modelProvider.selector.modelProviderSettings": "Modelproviderinstellingen", + "modelProvider.selector.noProviderConfigured": "Geen modelprovider geconfigureerd", + "modelProvider.selector.noProviderConfiguredDesc": "Blader in de Marketplace om er een te installeren, of configureer providers in de instellingen.", + "modelProvider.selector.onlyCompatibleModelsShown": "Alleen compatibele modellen worden weergegeven", "modelProvider.selector.rerankTip": "Please set up the Rerank model", "modelProvider.selector.tip": "This model has been removed. Please add a model or select another model.", "modelProvider.setupModelFirst": "Please set up your model first", diff --git a/web/i18n/nl-NL/plugin.json b/web/i18n/nl-NL/plugin.json index c7f091a442..4e0301b5f5 100644 --- a/web/i18n/nl-NL/plugin.json +++ b/web/i18n/nl-NL/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Remove plugin", "action.deleteContentLeft": "Would you like to remove ", "action.deleteContentRight": " plugin?", + "action.deleteSuccess": "Plugin succesvol verwijderd", "action.pluginInfo": "Plugin info", "action.usedInApps": "This plugin is being used in {{num}} apps.", "allCategories": "All Categories", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Install", "detailPanel.operation.remove": "Remove", "detailPanel.operation.update": "Update", + "detailPanel.operation.updateTooltip": "Werk bij voor toegang tot de nieuwste modellen.", "detailPanel.operation.viewDetail": "View Detail", "detailPanel.serviceOk": "Service OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUDED", @@ -231,12 +233,18 @@ "source.local": "Local Package File", "source.marketplace": "Marketplace", "task.clearAll": "Clear all", + "task.errorMsg.github": "Deze plugin kon niet automatisch worden geïnstalleerd.\nInstalleer deze vanuit GitHub.", + "task.errorMsg.marketplace": "Deze plugin kon niet automatisch worden geïnstalleerd.\nInstalleer deze vanuit de Marketplace.", + "task.errorMsg.unknown": "Deze plugin kon niet worden geïnstalleerd.\nDe pluginbron kon niet worden geïdentificeerd.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} plugins failed to install, click to view", + "task.installFromGithub": "Installeren vanuit GitHub", + "task.installFromMarketplace": "Installeren vanuit de Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} plugins failed to install", "task.installing": "Installing plugins", + "task.installingHint": "Installeren...
Dit kan enkele minuten duren.", "task.installingWithError": "Installing {{installingLength}} plugins, {{successLength}} success, {{errorLength}} failed", "task.installingWithSuccess": "Installing {{installingLength}} plugins, {{successLength}} success.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/nl-NL/workflow.json b/web/i18n/nl-NL/workflow.json index 35555d0e7b..891df72387 100644 --- a/web/i18n/nl-NL/workflow.json +++ b/web/i18n/nl-NL/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "updating workflow", "error.startNodeRequired": "Please add a start node first before {{operation}}", "errorMsg.authRequired": "Authorization is required", + "errorMsg.configureModel": "Configureer een model", "errorMsg.fieldRequired": "{{field}} is required", "errorMsg.fields.code": "Code", "errorMsg.fields.model": "Model", @@ -313,7 +314,8 @@ "errorMsg.fields.variableValue": "Variable Value", "errorMsg.fields.visionVariable": "Vision Variable", "errorMsg.invalidJson": "{{field}} is invalid JSON", - "errorMsg.invalidVariable": "Invalid variable", + "errorMsg.invalidVariable": "Ongeldige variabele. Selecteer een bestaande variabele.", + "errorMsg.modelPluginNotInstalled": "Ongeldige variabele. Configureer een model om deze variabele in te schakelen.", "errorMsg.noValidTool": "{{field}} no valid tool selected", "errorMsg.rerankModelRequired": "A configured Rerank Model is required", "errorMsg.startNodeRequired": "Please add a start node first before {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Window Size", "nodes.common.outputVars": "Output Variables", "nodes.common.pluginNotInstalled": "Plugin is not installed", + "nodes.common.pluginsNotInstalled": "{{count}} plugins niet geïnstalleerd", "nodes.common.retry.maxRetries": "max retries", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Retries", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Chunks", "nodes.knowledgeBase.chunksInputTip": "The input variable of the knowledge base node is Chunks. 
The variable type is an object with a specific JSON Schema which must be consistent with the selected chunk structure.", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks variable is required", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API-sleutel niet beschikbaar", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Tegoeden uitgeput", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatibel", "nodes.knowledgeBase.embeddingModelIsInvalid": "Embedding model is invalid", "nodes.knowledgeBase.embeddingModelIsRequired": "Embedding model is required", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Inbeddingsmodel niet geconfigureerd", "nodes.knowledgeBase.indexMethodIsRequired": "Index method is required", + "nodes.knowledgeBase.notConfigured": "Niet geconfigureerd", "nodes.knowledgeBase.rerankingModelIsInvalid": "Reranking model is invalid", "nodes.knowledgeBase.rerankingModelIsRequired": "Reranking model is required", "nodes.knowledgeBase.retrievalSettingIsRequired": "Retrieval setting is required", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Only supports Jinja2", "nodes.templateTransform.inputVars": "Input Variables", "nodes.templateTransform.outputVars.output": "Transformed content", + "nodes.tool.authorizationRequired": "Autorisatie vereist", "nodes.tool.authorize": "Authorize", "nodes.tool.inputVars": "Input Variables", "nodes.tool.insertPlaceholder1": "Type or press", @@ -1062,10 +1071,12 @@ "panel.change": "Change", "panel.changeBlock": "Change Node", "panel.checklist": "Checklist", + "panel.checklistDescription": "Los de volgende problemen op vóór het publiceren", "panel.checklistResolved": "All issues are resolved", "panel.checklistTip": "Make sure all issues are resolved before publishing", "panel.createdBy": "Created By ", "panel.goTo": "Go to", + "panel.goToFix": "Ga naar oplossing", "panel.helpLink": "View Docs", "panel.maximize": "Maximize Canvas", "panel.minimize": "Exit Full Screen", diff --git a/web/i18n/pl-PL/app-debug.json b/web/i18n/pl-PL/app-debug.json index 4d7c2bf2e5..d708c997e9 100644 --- a/web/i18n/pl-PL/app-debug.json +++ b/web/i18n/pl-PL/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "URUCHOM", "inputs.title": "Debugowanie i podgląd", "inputs.userInputField": "Pole wejściowe użytkownika", + "manageModels": "Zarządzaj modelami", "modelConfig.modeType.chat": "Czat", "modelConfig.modeType.completion": "Uzupełnienie", "modelConfig.model": "Model", "modelConfig.setTone": "Ustaw ton odpowiedzi", "modelConfig.title": "Model i parametry", + "noModelProviderConfigured": "Nie skonfigurowano dostawcy modeli", + "noModelProviderConfiguredTip": "Zainstaluj lub skonfiguruj dostawcę modeli, aby rozpocząć.", + "noModelSelected": "Nie wybrano modelu", + "noModelSelectedTip": "Skonfiguruj model powyżej, aby kontynuować.", "noResult": "W tym miejscu zostaną wyświetlone dane wyjściowe.", "notSetAPIKey.description": "Klucz dostawcy LLM nie został ustawiony, musi zostać ustawiony przed debugowaniem.", "notSetAPIKey.settingBtn": "Przejdź do ustawień", diff --git a/web/i18n/pl-PL/common.json b/web/i18n/pl-PL/common.json index 029e9dd660..130950a57c 100644 --- a/web/i18n/pl-PL/common.json +++ b/web/i18n/pl-PL/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Nieautoryzowany", "modelProvider.buyQuota": "Kup limit", "modelProvider.callTimes": "Czasy wywołań", + "modelProvider.card.aiCreditsInUse": "Używane AI credits", + "modelProvider.card.aiCreditsOption": "AI credits", +
"modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Wymagany klucz API", + "modelProvider.card.apiKeyUnavailableFallback": "API Key niedostępny, używane są AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Sprawdź konfigurację klucza API, aby przełączyć z powrotem", "modelProvider.card.buyQuota": "Kup limit", "modelProvider.card.callTimes": "Czasy wywołań", + "modelProvider.card.creditsExhaustedDescription": "Proszę ulepszyć swój plan lub skonfigurować klucz API", + "modelProvider.card.creditsExhaustedFallback": "AI credits wyczerpane, używany jest API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Ulepsz swój plan, aby przywrócić priorytet AI credits.", + "modelProvider.card.creditsExhaustedMessage": "AI credits zostały wyczerpane", "modelProvider.card.modelAPI": "Modele {{modelName}} używają klucza API.", "modelProvider.card.modelNotSupported": "Modele {{modelName}} nie są zainstalowane.", "modelProvider.card.modelSupported": "Modele {{modelName}} używają tego limitu.", + "modelProvider.card.noApiKeysDescription": "Dodaj klucz API, aby zacząć korzystać z własnych poświadczeń modelu.", + "modelProvider.card.noApiKeysFallback": "Brak kluczy API, używane są AI credits", + "modelProvider.card.noApiKeysTitle": "Nie skonfigurowano jeszcze kluczy API", + "modelProvider.card.noAvailableUsage": "Brak dostępnego użycia", "modelProvider.card.onTrial": "Na próbę", "modelProvider.card.paid": "Płatny", "modelProvider.card.priorityUse": "Używanie z priorytetem", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Usuń klucz API", "modelProvider.card.tip": "Kredyty wiadomości obsługują modele od {{modelNames}}. Priorytet zostanie nadany płatnemu limitowi. Darmowy limit zostanie użyty po wyczerpaniu płatnego limitu.", "modelProvider.card.tokens": "Tokeny", + "modelProvider.card.unavailable": "Niedostępne", + "modelProvider.card.upgradePlan": "ulepsz swój plan", + "modelProvider.card.usageLabel": "Użycie", + "modelProvider.card.usagePriority": "Priorytet użycia", + "modelProvider.card.usagePriorityTip": "Ustaw, który zasób ma być używany jako pierwszy przy uruchamianiu modeli.", "modelProvider.collapse": "Zwiń", "modelProvider.config": "Konfiguracja", "modelProvider.configLoadBalancing": "Równoważenie obciążenia konfiguracji", @@ -387,9 +406,11 @@ "modelProvider.model": "Model", "modelProvider.modelAndParameters": "Model i parametry", "modelProvider.modelHasBeenDeprecated": "Ten model jest przestarzały", + "modelProvider.modelSettings": "Ustawienia modelu", "modelProvider.models": "Modele", "modelProvider.modelsNum": "{{num}} Modele", "modelProvider.noModelFound": "Nie znaleziono modelu dla {{model}}", + "modelProvider.noneConfigured": "Skonfiguruj domyślny model systemowy, aby uruchamiać aplikacje", "modelProvider.notConfigured": "Systemowy model nie został jeszcze w pełni skonfigurowany, co może skutkować niedostępnością niektórych funkcji.", "modelProvider.parameters": "PARAMETRY", "modelProvider.parametersInvalidRemoved": "Niektóre parametry są nieprawidłowe i zostały usunięte.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Reset {{date}}", "modelProvider.searchModel": "Model wyszukiwania", "modelProvider.selectModel": "Wybierz swój model", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key niedostępny", + "modelProvider.selector.apiKeyUnavailableTip": "Klucz API został usunięty. 
Proszę skonfigurować nowy klucz API.", + "modelProvider.selector.configure": "Konfiguruj", + "modelProvider.selector.configureRequired": "Wymagana konfiguracja", + "modelProvider.selector.creditsExhausted": "Kredyty wyczerpane", + "modelProvider.selector.creditsExhaustedTip": "Twoje AI credits zostały wyczerpane. Proszę ulepszyć swój plan lub dodać klucz API.", + "modelProvider.selector.disabled": "Wyłączony", + "modelProvider.selector.discoverMoreInMarketplace": "Odkryj więcej w Marketplace", "modelProvider.selector.emptySetting": "Przejdź do ustawień, aby skonfigurować", "modelProvider.selector.emptyTip": "Brak dostępnych modeli", + "modelProvider.selector.fromMarketplace": "Z Marketplace", + "modelProvider.selector.incompatible": "Niekompatybilny", + "modelProvider.selector.incompatibleTip": "Ten model nie jest dostępny w bieżącej wersji. Proszę wybrać inny dostępny model.", + "modelProvider.selector.install": "Zainstaluj", + "modelProvider.selector.modelProviderSettings": "Ustawienia dostawcy modeli", + "modelProvider.selector.noProviderConfigured": "Nie skonfigurowano dostawcy modeli", + "modelProvider.selector.noProviderConfiguredDesc": "Przeglądaj Marketplace, aby zainstalować dostawcę, lub skonfiguruj dostawców w ustawieniach.", + "modelProvider.selector.onlyCompatibleModelsShown": "Wyświetlane są tylko kompatybilne modele", "modelProvider.selector.rerankTip": "Proszę skonfigurować model ponownego rankingu", "modelProvider.selector.tip": "Ten model został usunięty. Proszę dodać model lub wybrać inny model.", "modelProvider.setupModelFirst": "Proszę najpierw skonfigurować swój model", diff --git a/web/i18n/pl-PL/plugin.json b/web/i18n/pl-PL/plugin.json index abe94610dc..88ac549968 100644 --- a/web/i18n/pl-PL/plugin.json +++ b/web/i18n/pl-PL/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Usuń wtyczkę", "action.deleteContentLeft": "Czy chcesz usunąć", "action.deleteContentRight": "wtyczka?", + "action.deleteSuccess": "Wtyczka została pomyślnie usunięta", "action.pluginInfo": "Informacje o wtyczce", "action.usedInApps": "Ta wtyczka jest używana w aplikacjach {{num}}.", "allCategories": "Wszystkie kategorie", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Instalować", "detailPanel.operation.remove": "Usunąć", "detailPanel.operation.update": "Aktualizacja", + "detailPanel.operation.updateTooltip": "Zaktualizuj, aby uzyskać dostęp do najnowszych modeli.", "detailPanel.operation.viewDetail": "Pokaż szczegóły", "detailPanel.serviceOk": "Serwis OK", "detailPanel.strategyNum": "{{num}} {{strategy}} ZAWARTE", @@ -231,12 +233,18 @@ "source.local": "Lokalny plik pakietu", "source.marketplace": "Rynek", "task.clearAll": "Wyczyść wszystko", + "task.errorMsg.github": "Nie udało się automatycznie zainstalować tej wtyczki.\nProszę zainstalować ją z GitHub.", + "task.errorMsg.marketplace": "Nie udało się automatycznie zainstalować tej wtyczki.\nProszę zainstalować ją z Marketplace.", + "task.errorMsg.unknown": "Nie udało się zainstalować tej wtyczki.\nNie można zidentyfikować źródła wtyczki.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Nie udało się zainstalować wtyczek {{errorLength}}, kliknij, aby wyświetlić", + "task.installFromGithub": "Zainstaluj z GitHub", + "task.installFromMarketplace": "Zainstaluj z Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Nie udało się zainstalować wtyczek {{errorLength}}", "task.installing": "Instalowanie wtyczek.", + 
"task.installingHint": "Instalowanie... To może potrwać kilka minut.", "task.installingWithError": "Instalacja wtyczek {{installingLength}}, {{successLength}} powodzenie, {{errorLength}} niepowodzenie", "task.installingWithSuccess": "Instalacja wtyczek {{installingLength}}, {{successLength}} powodzenie.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/pl-PL/workflow.json b/web/i18n/pl-PL/workflow.json index 52514fe6b7..8aad8e0b71 100644 --- a/web/i18n/pl-PL/workflow.json +++ b/web/i18n/pl-PL/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "aktualizowanie przepływu pracy", "error.startNodeRequired": "Najpierw dodaj węzeł początkowy przed {{operation}}", "errorMsg.authRequired": "Wymagana autoryzacja", + "errorMsg.configureModel": "Skonfiguruj model", "errorMsg.fieldRequired": "{{field}} jest wymagane", "errorMsg.fields.code": "Kod", "errorMsg.fields.model": "Model", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Zmienna wizji", "errorMsg.invalidJson": "{{field}} jest nieprawidłowym JSON-em", "errorMsg.invalidVariable": "Nieprawidłowa zmienna", + "errorMsg.modelPluginNotInstalled": "Nieprawidłowa zmienna. Skonfiguruj model, aby włączyć tę zmienną.", "errorMsg.noValidTool": "{{field}} nie wybrano prawidłowego narzędzia", "errorMsg.rerankModelRequired": "Przed włączeniem Rerank Model upewnij się, że model został pomyślnie skonfigurowany w ustawieniach.", "errorMsg.startNodeRequired": "Najpierw dodaj węzeł początkowy przed {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Rozmiar okna", "nodes.common.outputVars": "Zmienne wyjściowe", "nodes.common.pluginNotInstalled": "Wtyczka nie jest zainstalowana", + "nodes.common.pluginsNotInstalled": "{{count}} wtyczek niezainstalowanych", "nodes.common.retry.maxRetries": "Maksymalna liczba ponownych prób", "nodes.common.retry.ms": "Ms", "nodes.common.retry.retries": "{{num}} Ponownych prób", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Kawałki", "nodes.knowledgeBase.chunksInputTip": "Zmienna wejściowa węzła bazy wiedzy to Chunks. 
Typ zmiennej to obiekt z określonym schematem JSON, który musi być zgodny z wybraną strukturą chunk.", "nodes.knowledgeBase.chunksVariableIsRequired": "Wymagana jest zmienna Chunks", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "Klucz API niedostępny", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Kredyty wyczerpane", + "nodes.knowledgeBase.embeddingModelIncompatible": "Niekompatybilny", "nodes.knowledgeBase.embeddingModelIsInvalid": "Model osadzania jest nieprawidłowy", "nodes.knowledgeBase.embeddingModelIsRequired": "Wymagany jest model osadzania", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Model embeddingu nie jest skonfigurowany", "nodes.knowledgeBase.indexMethodIsRequired": "Metoda indeksowa jest wymagana", + "nodes.knowledgeBase.notConfigured": "Nieskonfigurowany", "nodes.knowledgeBase.rerankingModelIsInvalid": "Model ponownego rankingowania jest nieprawidłowy", "nodes.knowledgeBase.rerankingModelIsRequired": "Wymagany jest model ponownego rankingu", "nodes.knowledgeBase.retrievalSettingIsRequired": "Wymagane jest ustawienie pobierania", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Obsługuje tylko Jinja2", "nodes.templateTransform.inputVars": "Zmienne wejściowe", "nodes.templateTransform.outputVars.output": "Przekształcona treść", + "nodes.tool.authorizationRequired": "Wymagana autoryzacja", "nodes.tool.authorize": "Autoryzuj", "nodes.tool.inputVars": "Zmienne wejściowe", "nodes.tool.insertPlaceholder1": "Wpisz lub naciśnij", @@ -1062,10 +1071,12 @@ "panel.change": "Zmień", "panel.changeBlock": "Zmień węzeł", "panel.checklist": "Lista kontrolna", + "panel.checklistDescription": "Rozwiąż poniższe problemy przed opublikowaniem", "panel.checklistResolved": "Wszystkie problemy zostały rozwiązane", "panel.checklistTip": "Upewnij się, że wszystkie problemy zostały rozwiązane przed opublikowaniem", "panel.createdBy": "Stworzone przez ", "panel.goTo": "Idź do", + "panel.goToFix": "Przejdź do naprawy", "panel.helpLink": "Pomoc", "panel.maximize": "Maksymalizuj płótno", "panel.minimize": "Wyjdź z trybu pełnoekranowego", diff --git a/web/i18n/pt-BR/app-debug.json b/web/i18n/pt-BR/app-debug.json index 5b0b4c9969..3b1a0da424 100644 --- a/web/i18n/pt-BR/app-debug.json +++ b/web/i18n/pt-BR/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "EXECUTAR", "inputs.title": "Depuração e Visualização", "inputs.userInputField": "Campo de Entrada do Usuário", + "manageModels": "Gerenciar modelos", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Completar", "modelConfig.model": "Modelo", "modelConfig.setTone": "Definir tom das respostas", "modelConfig.title": "Modelo e Parâmetros", + "noModelProviderConfigured": "Nenhum provedor de modelo configurado", + "noModelProviderConfiguredTip": "Instale ou configure um provedor de modelo para começar.", + "noModelSelected": "Nenhum modelo selecionado", + "noModelSelectedTip": "Configure um modelo acima para continuar.", "noResult": "A saída será exibida aqui.", "notSetAPIKey.description": "A chave do provedor LLM não foi definida e precisa ser definida antes da depuração.", "notSetAPIKey.settingBtn": "Ir para configurações", diff --git a/web/i18n/pt-BR/common.json b/web/i18n/pt-BR/common.json index 0f2106b1eb..6840bb964b 100644 --- a/web/i18n/pt-BR/common.json +++ b/web/i18n/pt-BR/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Não autorizado", "modelProvider.buyQuota": "Comprar Quota", "modelProvider.callTimes": "Chamadas", + "modelProvider.card.aiCreditsInUse": "AI 
credits em uso", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API Key necessária", + "modelProvider.card.apiKeyUnavailableFallback": "API Key indisponível, usando AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Verifique a configuração da sua API Key para voltar a usá-la", "modelProvider.card.buyQuota": "Comprar Quota", "modelProvider.card.callTimes": "Chamadas", + "modelProvider.card.creditsExhaustedDescription": "Por favor, atualize seu plano ou configure uma API Key", + "modelProvider.card.creditsExhaustedFallback": "AI credits esgotados, usando API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Atualize seu plano para retomar a prioridade de AI credits.", + "modelProvider.card.creditsExhaustedMessage": "AI credits foram esgotados", "modelProvider.card.modelAPI": "Os modelos {{modelName}} estão usando a Chave API.", "modelProvider.card.modelNotSupported": "Os modelos {{modelName}} não estão instalados.", "modelProvider.card.modelSupported": "Os modelos {{modelName}} estão usando esta cota.", + "modelProvider.card.noApiKeysDescription": "Adicione uma API Key para usar suas próprias credenciais de modelo.", + "modelProvider.card.noApiKeysFallback": "Sem API Keys, usando AI credits", + "modelProvider.card.noApiKeysTitle": "Nenhuma API Key configurada", + "modelProvider.card.noAvailableUsage": "Sem uso disponível", "modelProvider.card.onTrial": "Em Teste", "modelProvider.card.paid": "Pago", "modelProvider.card.priorityUse": "Uso prioritário", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Remover Chave da API", "modelProvider.card.tip": "Créditos de mensagens suportam modelos de {{modelNames}}. A prioridade será dada à quota paga. A quota gratuita será usada após a quota paga ser esgotada.", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "Indisponível", + "modelProvider.card.upgradePlan": "atualize seu plano", + "modelProvider.card.usageLabel": "Uso", + "modelProvider.card.usagePriority": "Prioridade de uso", + "modelProvider.card.usagePriorityTip": "Defina qual recurso usar primeiro ao executar modelos.", "modelProvider.collapse": "Recolher", "modelProvider.config": "Configuração", "modelProvider.configLoadBalancing": "Balanceamento de carga de configuração", @@ -387,9 +406,11 @@ "modelProvider.model": "Modelo", "modelProvider.modelAndParameters": "Modelo e Parâmetros", "modelProvider.modelHasBeenDeprecated": "Este modelo foi preterido", + "modelProvider.modelSettings": "Configurações de modelo", "modelProvider.models": "Modelos", "modelProvider.modelsNum": "{{num}} Modelos", "modelProvider.noModelFound": "Nenhum modelo encontrado para {{model}}", + "modelProvider.noneConfigured": "Configure um modelo de sistema padrão para executar aplicações", "modelProvider.notConfigured": "O modelo do sistema ainda não foi totalmente configurado e algumas funções podem estar indisponíveis.", "modelProvider.parameters": "PARÂMETROS", "modelProvider.parametersInvalidRemoved": "Alguns parâmetros são inválidos e foram removidos", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Redefinir em {{date}}", "modelProvider.searchModel": "Modelo de pesquisa", "modelProvider.selectModel": "Selecione seu modelo", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key indisponível", + "modelProvider.selector.apiKeyUnavailableTip": "A API Key foi removida. 
Por favor, configure uma nova API Key.", + "modelProvider.selector.configure": "Configurar", + "modelProvider.selector.configureRequired": "Configuração necessária", + "modelProvider.selector.creditsExhausted": "Créditos esgotados", + "modelProvider.selector.creditsExhaustedTip": "Seus AI credits foram esgotados. Por favor, atualize seu plano ou adicione uma API Key.", + "modelProvider.selector.disabled": "Desativado", + "modelProvider.selector.discoverMoreInMarketplace": "Descubra mais no Marketplace", "modelProvider.selector.emptySetting": "Por favor, vá para configurações para configurar", "modelProvider.selector.emptyTip": "Nenhum modelo disponível", + "modelProvider.selector.fromMarketplace": "Do Marketplace", + "modelProvider.selector.incompatible": "Incompatível", + "modelProvider.selector.incompatibleTip": "Este modelo não está disponível na versão atual. Por favor, selecione outro modelo disponível.", + "modelProvider.selector.install": "Instalar", + "modelProvider.selector.modelProviderSettings": "Configurações do provedor de modelo", + "modelProvider.selector.noProviderConfigured": "Nenhum provedor de modelo configurado", + "modelProvider.selector.noProviderConfiguredDesc": "Navegue pelo Marketplace para instalar um, ou configure provedores nas configurações.", + "modelProvider.selector.onlyCompatibleModelsShown": "Apenas modelos compatíveis são exibidos", "modelProvider.selector.rerankTip": "Por favor, configure o modelo de reordenação", "modelProvider.selector.tip": "Este modelo foi removido. Adicione um modelo ou selecione outro modelo.", "modelProvider.setupModelFirst": "Por favor, configure seu modelo primeiro", diff --git a/web/i18n/pt-BR/plugin.json b/web/i18n/pt-BR/plugin.json index c9297b5840..d5d218c8b6 100644 --- a/web/i18n/pt-BR/plugin.json +++ b/web/i18n/pt-BR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Remover plugin", "action.deleteContentLeft": "Gostaria de remover", "action.deleteContentRight": "plugin?", + "action.deleteSuccess": "Plugin removido com sucesso", "action.pluginInfo": "Informações do plugin", "action.usedInApps": "Este plugin está sendo usado em aplicativos {{num}}.", "allCategories": "Todas as categorias", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Instalar", "detailPanel.operation.remove": "Retirar", "detailPanel.operation.update": "Atualização", + "detailPanel.operation.updateTooltip": "Atualize para acessar os modelos mais recentes.", "detailPanel.operation.viewDetail": "Ver detalhes", "detailPanel.serviceOk": "Serviço OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUSO", @@ -231,12 +233,18 @@ "source.local": "Arquivo de pacote local", "source.marketplace": "Mercado", "task.clearAll": "Apagar tudo", + "task.errorMsg.github": "Este plugin não pôde ser instalado automaticamente.\nPor favor, instale-o pelo GitHub.", + "task.errorMsg.marketplace": "Este plugin não pôde ser instalado automaticamente.\nPor favor, instale-o pelo Marketplace.", + "task.errorMsg.unknown": "Este plugin não pôde ser instalado.\nNão foi possível identificar a origem do plugin.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} plugins falha ao instalar, clique para ver", + "task.installFromGithub": "Instalar pelo GitHub", + "task.installFromMarketplace": "Instalar pelo Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Falha na instalação dos plug-ins {{errorLength}}", "task.installing": "Instalando plugins.", + 
"task.installingHint": "Instalando... Isso pode levar alguns minutos.", "task.installingWithError": "Instalando plug-ins {{installingLength}}, {{successLength}} sucesso, {{errorLength}} falhou", "task.installingWithSuccess": "Instalando plugins {{installingLength}}, {{successLength}} sucesso.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/pt-BR/workflow.json b/web/i18n/pt-BR/workflow.json index 3edab361bb..dc986df82c 100644 --- a/web/i18n/pt-BR/workflow.json +++ b/web/i18n/pt-BR/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "atualizando fluxo de trabalho", "error.startNodeRequired": "Por favor, adicione um nó inicial antes de {{operation}}", "errorMsg.authRequired": "Autorização é necessária", + "errorMsg.configureModel": "Configure um modelo", "errorMsg.fieldRequired": "{{field}} é obrigatório", "errorMsg.fields.code": "Código", "errorMsg.fields.model": "Modelo", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Variável de visão", "errorMsg.invalidJson": "{{field}} é um JSON inválido", "errorMsg.invalidVariable": "Variável inválida", + "errorMsg.modelPluginNotInstalled": "Variável inválida. Configure um modelo para habilitar esta variável.", "errorMsg.noValidTool": "{{field}} nenhuma ferramenta válida selecionada", "errorMsg.rerankModelRequired": "Antes de ativar o modelo de reclassificação, confirme se o modelo foi configurado com sucesso nas configurações.", "errorMsg.startNodeRequired": "Por favor, adicione um nó inicial antes de {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Tamanho da janela", "nodes.common.outputVars": "Variáveis de saída", "nodes.common.pluginNotInstalled": "O plugin não está instalado", + "nodes.common.pluginsNotInstalled": "{{count}} plugins não instalados", "nodes.common.retry.maxRetries": "Máximo de tentativas", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Tentativas", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Pedaços", "nodes.knowledgeBase.chunksInputTip": "A variável de entrada do nó da base de conhecimento é Chunks. 
O tipo da variável é um objeto com um esquema JSON específico que deve ser consistente com a estrutura de chunk selecionada.", "nodes.knowledgeBase.chunksVariableIsRequired": "A variável 'chunks' é obrigatória", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key indisponível", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Créditos esgotados", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatível", "nodes.knowledgeBase.embeddingModelIsInvalid": "O modelo de incorporação é inválido", "nodes.knowledgeBase.embeddingModelIsRequired": "Modelo de incorporação é necessário", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modelo de embedding não configurado", "nodes.knowledgeBase.indexMethodIsRequired": "O método de índice é necessário", + "nodes.knowledgeBase.notConfigured": "Não configurado", "nodes.knowledgeBase.rerankingModelIsInvalid": "O modelo de reclassificação é inválido", "nodes.knowledgeBase.rerankingModelIsRequired": "Um modelo de reclassificação é necessário", "nodes.knowledgeBase.retrievalSettingIsRequired": "A configuração de recuperação é necessária", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Suporta apenas Jinja2", "nodes.templateTransform.inputVars": "Variáveis de entrada", "nodes.templateTransform.outputVars.output": "Conteúdo transformado", + "nodes.tool.authorizationRequired": "Autorização necessária", "nodes.tool.authorize": "Autorizar", "nodes.tool.inputVars": "Variáveis de entrada", "nodes.tool.insertPlaceholder1": "Digite ou pressione", @@ -1062,10 +1071,12 @@ "panel.change": "Mudar", "panel.changeBlock": "Mudar Nó", "panel.checklist": "Lista de verificação", + "panel.checklistDescription": "Resolva os seguintes problemas antes de publicar", "panel.checklistResolved": "Todos os problemas foram resolvidos", "panel.checklistTip": "Certifique-se de que todos os problemas foram resolvidos antes de publicar", "panel.createdBy": "Criado por ", "panel.goTo": "Ir para", + "panel.goToFix": "Ir para correção", "panel.helpLink": "Ajuda", "panel.maximize": "Maximize Canvas", "panel.minimize": "Sair do Modo Tela Cheia", diff --git a/web/i18n/ro-RO/app-debug.json b/web/i18n/ro-RO/app-debug.json index 6245ca680d..d2b8294804 100644 --- a/web/i18n/ro-RO/app-debug.json +++ b/web/i18n/ro-RO/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "RULARE", "inputs.title": "Depanare și previzualizare", "inputs.userInputField": "Câmp de intrare utilizator", + "manageModels": "Gestionare modele", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Completare", "modelConfig.model": "Model", "modelConfig.setTone": "Setați tonul răspunsurilor", "modelConfig.title": "Model și Parametri", + "noModelProviderConfigured": "Niciun furnizor de modele configurat", + "noModelProviderConfiguredTip": "Instalați sau configurați un furnizor de modele pentru a începe.", + "noModelSelected": "Niciun model selectat", + "noModelSelectedTip": "Configurați un model mai sus pentru a continua.", "noResult": "Ieșirea va fi afișată aici.", "notSetAPIKey.description": "Cheia furnizorului LLM nu a fost setată și trebuie să fie setată înainte de depanare.", "notSetAPIKey.settingBtn": "Du-te la setări", diff --git a/web/i18n/ro-RO/common.json b/web/i18n/ro-RO/common.json index 8fdb09b3c1..306439768b 100644 --- a/web/i18n/ro-RO/common.json +++ b/web/i18n/ro-RO/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Neautorizat", "modelProvider.buyQuota": "Cumpără cotă", "modelProvider.callTimes": "Apeluri", +
"modelProvider.card.aiCreditsInUse": "AI credits în uz", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API key necesar", + "modelProvider.card.apiKeyUnavailableFallback": "API Key indisponibil, se utilizează AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Verificați configurația API key pentru a reveni", "modelProvider.card.buyQuota": "Cumpără cotă", "modelProvider.card.callTimes": "Apeluri", + "modelProvider.card.creditsExhaustedDescription": "Vă rugăm să faceți upgrade la plan sau să configurați un API key", + "modelProvider.card.creditsExhaustedFallback": "AI credits epuizate, se utilizează API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Faceți upgrade la plan pentru a relua prioritatea AI credits.", + "modelProvider.card.creditsExhaustedMessage": "AI credits au fost epuizate", "modelProvider.card.modelAPI": "Modelele {{modelName}} folosesc cheia API.", "modelProvider.card.modelNotSupported": "Modelele {{modelName}} nu sunt instalate.", "modelProvider.card.modelSupported": "Modelele {{modelName}} folosesc această cotă.", + "modelProvider.card.noApiKeysDescription": "Adăugați un API key pentru a începe să utilizați propriile credențiale de model.", + "modelProvider.card.noApiKeysFallback": "Fără API key-uri, se utilizează AI credits", + "modelProvider.card.noApiKeysTitle": "Niciun API key configurat încă", + "modelProvider.card.noAvailableUsage": "Nicio utilizare disponibilă", "modelProvider.card.onTrial": "În probă", "modelProvider.card.paid": "Plătit", "modelProvider.card.priorityUse": "Utilizare prioritară", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Elimină cheia API", "modelProvider.card.tip": "Creditele de mesaje acceptă modele de la {{modelNames}}. Prioritate va fi acordată cotei plătite. 
Cota gratuită va fi utilizată după epuizarea cotei plătite.", "modelProvider.card.tokens": "Jetoane", + "modelProvider.card.unavailable": "Indisponibil", + "modelProvider.card.upgradePlan": "faceți upgrade la plan", + "modelProvider.card.usageLabel": "Utilizare", + "modelProvider.card.usagePriority": "Prioritate de utilizare", + "modelProvider.card.usagePriorityTip": "Setați ce resursă să fie utilizată prima la rularea modelelor.", "modelProvider.collapse": "Restrânge", "modelProvider.config": "Configurare", "modelProvider.configLoadBalancing": "Echilibrarea încărcării de configurare", @@ -387,9 +406,11 @@ "modelProvider.model": "Model", "modelProvider.modelAndParameters": "Model și parametri", "modelProvider.modelHasBeenDeprecated": "Acest model a fost depreciat", + "modelProvider.modelSettings": "Setări model", "modelProvider.models": "Modele", "modelProvider.modelsNum": "{{num}} Modele", "modelProvider.noModelFound": "Nu a fost găsit niciun model pentru {{model}}", + "modelProvider.noneConfigured": "Configurați un model de sistem implicit pentru a rula aplicații", "modelProvider.notConfigured": "Modelul de sistem nu a fost încă configurat complet, iar unele funcții pot fi indisponibile.", "modelProvider.parameters": "PARAMETRI", "modelProvider.parametersInvalidRemoved": "Unele parametrii sunt invalizi și au fost eliminați.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Resetare la {{date}}", "modelProvider.searchModel": "Model de căutare", "modelProvider.selectModel": "Selectați modelul dvs.", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key indisponibil", + "modelProvider.selector.apiKeyUnavailableTip": "API key-ul a fost eliminat. Vă rugăm să configurați un nou API key.", + "modelProvider.selector.configure": "Configurare", + "modelProvider.selector.configureRequired": "Configurare necesară", + "modelProvider.selector.creditsExhausted": "Credite epuizate", + "modelProvider.selector.creditsExhaustedTip": "AI credits au fost epuizate. Vă rugăm să faceți upgrade la plan sau să adăugați un API key.", + "modelProvider.selector.disabled": "Dezactivat", + "modelProvider.selector.discoverMoreInMarketplace": "Descoperiți mai multe în Marketplace", "modelProvider.selector.emptySetting": "Vă rugăm să mergeți la setări pentru a configura", "modelProvider.selector.emptyTip": "Nu există modele disponibile", + "modelProvider.selector.fromMarketplace": "Din Marketplace", + "modelProvider.selector.incompatible": "Incompatibil", + "modelProvider.selector.incompatibleTip": "Acest model nu este disponibil în versiunea curentă. Vă rugăm să selectați alt model disponibil.", + "modelProvider.selector.install": "Instalare", + "modelProvider.selector.modelProviderSettings": "Setări furnizor de modele", + "modelProvider.selector.noProviderConfigured": "Niciun furnizor de modele configurat", + "modelProvider.selector.noProviderConfiguredDesc": "Răsfoiți Marketplace pentru a instala unul sau configurați furnizorii în setări.", + "modelProvider.selector.onlyCompatibleModelsShown": "Sunt afișate doar modelele compatibile", "modelProvider.selector.rerankTip": "Vă rugăm să configurați modelul de reordonare", "modelProvider.selector.tip": "Acest model a fost eliminat. 
Vă rugăm să adăugați un model sau să selectați un alt model.", "modelProvider.setupModelFirst": "Vă rugăm să configurați mai întâi modelul", diff --git a/web/i18n/ro-RO/plugin.json b/web/i18n/ro-RO/plugin.json index c0295b64d6..15523b6c62 100644 --- a/web/i18n/ro-RO/plugin.json +++ b/web/i18n/ro-RO/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Eliminați pluginul", "action.deleteContentLeft": "Doriți să eliminați", "action.deleteContentRight": "plugin?", + "action.deleteSuccess": "Plugin eliminat cu succes", "action.pluginInfo": "Informații despre plugin", "action.usedInApps": "Acest plugin este folosit în aplicațiile {{num}}.", "allCategories": "Toate categoriile", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Instala", "detailPanel.operation.remove": "Depărta", "detailPanel.operation.update": "Actualiza", + "detailPanel.operation.updateTooltip": "Actualizați pentru a accesa cele mai recente modele.", "detailPanel.operation.viewDetail": "Vezi detalii", "detailPanel.serviceOk": "Serviciu OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUS", @@ -231,12 +233,18 @@ "source.local": "Fișier pachet local", "source.marketplace": "Târg", "task.clearAll": "Ștergeți tot", + "task.errorMsg.github": "Acest plugin nu a putut fi instalat automat.\nVă rugăm să îl instalați de pe GitHub.", + "task.errorMsg.marketplace": "Acest plugin nu a putut fi instalat automat.\nVă rugăm să îl instalați din Marketplace.", + "task.errorMsg.unknown": "Acest plugin nu a putut fi instalat.\nSursa pluginului nu a putut fi identificată.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} plugin-urile nu s-au instalat, faceți clic pentru a vizualiza", + "task.installFromGithub": "Instalare de pe GitHub", + "task.installFromMarketplace": "Instalare din Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} plugin-urile nu s-au instalat", "task.installing": "Se instalează pluginuri.", + "task.installingHint": "Se instalează... Aceasta poate dura câteva minute.", "task.installingWithError": "Instalarea pluginurilor {{installingLength}}, {{successLength}} succes, {{errorLength}} eșuat", "task.installingWithSuccess": "Instalarea pluginurilor {{installingLength}}, {{successLength}} succes.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/ro-RO/workflow.json b/web/i18n/ro-RO/workflow.json index a68d477980..e34ba42007 100644 --- a/web/i18n/ro-RO/workflow.json +++ b/web/i18n/ro-RO/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "actualizarea fluxului de lucru", "error.startNodeRequired": "Vă rugăm să adăugați mai întâi un nod de pornire înainte de {{operation}}", "errorMsg.authRequired": "Autorizarea este necesară", + "errorMsg.configureModel": "Configurați un model", "errorMsg.fieldRequired": "{{field}} este obligatoriu", "errorMsg.fields.code": "Cod", "errorMsg.fields.model": "Model", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vizibilitate variabilă", "errorMsg.invalidJson": "{{field}} este un JSON invalid", "errorMsg.invalidVariable": "Variabilă invalidă", + "errorMsg.modelPluginNotInstalled": "Variabilă invalidă. 
Configurați un model pentru a activa această variabilă.", "errorMsg.noValidTool": "{{field}} nu a fost selectat niciun instrument valid", "errorMsg.rerankModelRequired": "Înainte de a activa modelul de reclasificare, vă rugăm să confirmați că modelul a fost configurat cu succes în setări.", "errorMsg.startNodeRequired": "Vă rugăm să adăugați mai întâi un nod de pornire înainte de {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Dimensiunea ferestrei", "nodes.common.outputVars": "Variabile de ieșire", "nodes.common.pluginNotInstalled": "Pluginul nu este instalat", + "nodes.common.pluginsNotInstalled": "{{count}} pluginuri neinstalate", "nodes.common.retry.maxRetries": "numărul maxim de încercări", "nodes.common.retry.ms": "Ms", "nodes.common.retry.retries": "{{num}} Încercări", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Bucăți", "nodes.knowledgeBase.chunksInputTip": "Variabila de intrare a nodului bazei de cunoștințe este Chunks. Tipul variabilei este un obiect cu un Șchema JSON specific care trebuie să fie coerent cu structura de chunk selectată.", "nodes.knowledgeBase.chunksVariableIsRequired": "Variabila Chunks este obligatorie", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API key indisponibil", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Credite epuizate", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatibil", "nodes.knowledgeBase.embeddingModelIsInvalid": "Modelul de încorporare este invalid", "nodes.knowledgeBase.embeddingModelIsRequired": "Este necesar un model de încorporare", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modelul de embedding nu este configurat", "nodes.knowledgeBase.indexMethodIsRequired": "Este necesară metoda indexului", + "nodes.knowledgeBase.notConfigured": "Neconfigurat", "nodes.knowledgeBase.rerankingModelIsInvalid": "Modelul de reordonare este invalid", "nodes.knowledgeBase.rerankingModelIsRequired": "Este necesar un model de reordonare", "nodes.knowledgeBase.retrievalSettingIsRequired": "Setarea de recuperare este necesară", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Suportă doar Jinja2", "nodes.templateTransform.inputVars": "Variabile de intrare", "nodes.templateTransform.outputVars.output": "Conținut transformat", + "nodes.tool.authorizationRequired": "Autorizare necesară", "nodes.tool.authorize": "Autorizați", "nodes.tool.inputVars": "Variabile de intrare", "nodes.tool.insertPlaceholder1": "Scrieți sau apăsați", @@ -1062,10 +1071,12 @@ "panel.change": "Schimbă", "panel.changeBlock": "Schimbă nodul", "panel.checklist": "Lista de verificare", + "panel.checklistDescription": "Rezolvați următoarele probleme înainte de publicare", "panel.checklistResolved": "Toate problemele au fost rezolvate", "panel.checklistTip": "Asigurați-vă că toate problemele sunt rezolvate înainte de publicare", "panel.createdBy": "Creat de ", "panel.goTo": "Du-te la", + "panel.goToFix": "Mergi la remediere", "panel.helpLink": "Ajutor", "panel.maximize": "Maximize Canvas", "panel.minimize": "Iesi din modul pe tot ecranul", diff --git a/web/i18n/ru-RU/app-debug.json b/web/i18n/ru-RU/app-debug.json index fdc797da2f..514958101d 100644 --- a/web/i18n/ru-RU/app-debug.json +++ b/web/i18n/ru-RU/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "ЗАПУСТИТЬ", "inputs.title": "Отладка и предварительный просмотр", "inputs.userInputField": "Поле пользовательского ввода", + "manageModels": "Управление моделями", "modelConfig.modeType.chat": "Чат", "modelConfig.modeType.completion": "Завершение", 
"modelConfig.model": "Модель", "modelConfig.setTone": "Установить тон ответов", "modelConfig.title": "Модель и параметры", + "noModelProviderConfigured": "Поставщик моделей не настроен", + "noModelProviderConfiguredTip": "Установите или настройте поставщика моделей, чтобы начать работу.", + "noModelSelected": "Модель не выбрана", + "noModelSelectedTip": "Настройте модель выше, чтобы продолжить.", "noResult": "Вывод будет отображаться здесь.", "notSetAPIKey.description": "Ключ поставщика LLM не установлен, его необходимо установить перед отладкой.", "notSetAPIKey.settingBtn": "Перейти к настройкам", diff --git a/web/i18n/ru-RU/common.json b/web/i18n/ru-RU/common.json index dc301e4ed9..aec9a69483 100644 --- a/web/i18n/ru-RU/common.json +++ b/web/i18n/ru-RU/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Неавторизованный", "modelProvider.buyQuota": "Купить квоту", "modelProvider.callTimes": "Количество вызовов", + "modelProvider.card.aiCreditsInUse": "Используются AI-кредиты", + "modelProvider.card.aiCreditsOption": "AI-кредиты", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Требуется API-ключ", + "modelProvider.card.apiKeyUnavailableFallback": "API Key недоступен, используются AI-кредиты", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Проверьте настройки API Key для переключения обратно", "modelProvider.card.buyQuota": "Купить квоту", "modelProvider.card.callTimes": "Количество вызовов", + "modelProvider.card.creditsExhaustedDescription": "Пожалуйста, обновите тарифный план или настройте API-ключ", + "modelProvider.card.creditsExhaustedFallback": "AI-кредиты исчерпаны, используется API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Обновите тарифный план, чтобы возобновить приоритетное использование AI-кредитов.", + "modelProvider.card.creditsExhaustedMessage": "AI-кредиты исчерпаны", "modelProvider.card.modelAPI": "Модели {{modelName}} используют API-ключ.", "modelProvider.card.modelNotSupported": "Модели {{modelName}} не установлены.", "modelProvider.card.modelSupported": "Модели {{modelName}} используют эту квоту.", + "modelProvider.card.noApiKeysDescription": "Добавьте API-ключ, чтобы использовать собственные учётные данные модели.", + "modelProvider.card.noApiKeysFallback": "API-ключи не настроены, используются AI-кредиты", + "modelProvider.card.noApiKeysTitle": "API-ключи ещё не настроены", + "modelProvider.card.noAvailableUsage": "Нет доступного использования", "modelProvider.card.onTrial": "Пробная версия", "modelProvider.card.paid": "Платный", "modelProvider.card.priorityUse": "Приоритетное использование", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Удалить API-ключ", "modelProvider.card.tip": "Кредиты сообщений поддерживают модели от {{modelNames}}. Приоритет будет отдаваться платной квоте. 
Бесплатная квота будет использоваться после исчерпания платной квоты.",
   "modelProvider.card.tokens": "Токены",
+  "modelProvider.card.unavailable": "Недоступно",
+  "modelProvider.card.upgradePlan": "обновите тарифный план",
+  "modelProvider.card.usageLabel": "Использование",
+  "modelProvider.card.usagePriority": "Приоритет использования",
+  "modelProvider.card.usagePriorityTip": "Укажите, какой ресурс использовать первым при запуске моделей.",
   "modelProvider.collapse": "Свернуть",
   "modelProvider.config": "Настройка",
   "modelProvider.configLoadBalancing": "Настроить балансировку нагрузки",
@@ -387,9 +406,11 @@
   "modelProvider.model": "Модель",
   "modelProvider.modelAndParameters": "Модель и параметры",
   "modelProvider.modelHasBeenDeprecated": "Эта модель устарела",
+  "modelProvider.modelSettings": "Настройки моделей",
   "modelProvider.models": "Модели",
   "modelProvider.modelsNum": "{{num}} Моделей",
   "modelProvider.noModelFound": "Модель не найдена для {{model}}",
+  "modelProvider.noneConfigured": "Настройте системную модель по умолчанию для запуска приложений",
   "modelProvider.notConfigured": "Системная модель еще не полностью настроена, и некоторые функции могут быть недоступны.",
   "modelProvider.parameters": "ПАРАМЕТРЫ",
   "modelProvider.parametersInvalidRemoved": "Некоторые параметры недействительны и были удалены",
@@ -403,8 +424,25 @@
   "modelProvider.resetDate": "Сброс {{date}}",
   "modelProvider.searchModel": "Поиск модели",
   "modelProvider.selectModel": "Выберите свою модель",
+  "modelProvider.selector.aiCredits": "AI-кредиты",
+  "modelProvider.selector.apiKeyUnavailable": "API Key недоступен",
+  "modelProvider.selector.apiKeyUnavailableTip": "API Key был удалён. Пожалуйста, настройте новый API Key.",
+  "modelProvider.selector.configure": "Настроить",
+  "modelProvider.selector.configureRequired": "Требуется настройка",
+  "modelProvider.selector.creditsExhausted": "Кредиты исчерпаны",
+  "modelProvider.selector.creditsExhaustedTip": "Ваши AI-кредиты исчерпаны. Обновите тарифный план или добавьте API-ключ.",
+  "modelProvider.selector.disabled": "Отключено",
+  "modelProvider.selector.discoverMoreInMarketplace": "Найти больше в Marketplace",
   "modelProvider.selector.emptySetting": "Пожалуйста, перейдите в настройки для настройки",
   "modelProvider.selector.emptyTip": "Нет доступных моделей",
+  "modelProvider.selector.fromMarketplace": "Из Marketplace",
+  "modelProvider.selector.incompatible": "Несовместимо",
+  "modelProvider.selector.incompatibleTip": "Эта модель недоступна в текущей версии. Пожалуйста, выберите другую доступную модель.",
+  "modelProvider.selector.install": "Установить",
+  "modelProvider.selector.modelProviderSettings": "Настройки поставщика моделей",
+  "modelProvider.selector.noProviderConfigured": "Поставщик моделей не настроен",
+  "modelProvider.selector.noProviderConfiguredDesc": "Установите из Marketplace или настройте поставщиков в настройках.",
+  "modelProvider.selector.onlyCompatibleModelsShown": "Показаны только совместимые модели",
   "modelProvider.selector.rerankTip": "Пожалуйста, настройте модель повторного ранжирования",
   "modelProvider.selector.tip": "Эта модель была удалена. Пожалуйста, добавьте модель или выберите другую модель.",
   "modelProvider.setupModelFirst": "Пожалуйста, сначала настройте свою модель",
diff --git a/web/i18n/ru-RU/plugin.json b/web/i18n/ru-RU/plugin.json
index 3b10a3a995..52cb6cbd55 100644
--- a/web/i18n/ru-RU/plugin.json
+++ b/web/i18n/ru-RU/plugin.json
@@ -3,6 +3,7 @@
   "action.delete": "Удалить плагин",
   "action.deleteContentLeft": "Вы хотели бы удалить",
   "action.deleteContentRight": "Плагин?",
+  "action.deleteSuccess": "Плагин успешно удалён",
   "action.pluginInfo": "Информация о плагине",
   "action.usedInApps": "Этот плагин используется в приложениях {{num}}.",
   "allCategories": "Все категории",
@@ -114,6 +115,7 @@
   "detailPanel.operation.install": "Устанавливать",
   "detailPanel.operation.remove": "Убирать",
   "detailPanel.operation.update": "Обновлять",
+  "detailPanel.operation.updateTooltip": "Обновите для доступа к последним моделям.",
   "detailPanel.operation.viewDetail": "Подробнее",
   "detailPanel.serviceOk": "Услуга ОК",
   "detailPanel.strategyNum": "{{num}} {{strategy}} ВКЛЮЧЕННЫЙ",
@@ -231,12 +233,18 @@
   "source.local": "Локальный файл пакета",
   "source.marketplace": "Рынок",
   "task.clearAll": "Очистить все",
+  "task.errorMsg.github": "Не удалось автоматически установить этот плагин.\nПожалуйста, установите его из GitHub.",
+  "task.errorMsg.marketplace": "Не удалось автоматически установить этот плагин.\nПожалуйста, установите его из Marketplace.",
+  "task.errorMsg.unknown": "Не удалось установить этот плагин.\nНе удалось определить источник плагина.",
   "task.errorPlugins": "Failed to Install Plugins",
   "task.installError": "Плагины {{errorLength}} не удалось установить, нажмите для просмотра",
+  "task.installFromGithub": "Установить из GitHub",
+  "task.installFromMarketplace": "Установить из Marketplace",
   "task.installSuccess": "{{successLength}} plugins installed successfully",
   "task.installed": "Installed",
   "task.installedError": "плагины {{errorLength}} не удалось установить",
   "task.installing": "Установка плагинов.",
+  "task.installingHint": "Установка... Это может занять несколько минут.",
   "task.installingWithError": "Установка плагинов {{installingLength}}, {{successLength}} успех, {{errorLength}} неудачный",
   "task.installingWithSuccess": "Установка плагинов {{installingLength}}, {{successLength}} успех.",
   "task.runningPlugins": "Installing Plugins",
diff --git a/web/i18n/ru-RU/workflow.json b/web/i18n/ru-RU/workflow.json
index 9265e4e6d6..73cdad253a 100644
--- a/web/i18n/ru-RU/workflow.json
+++ b/web/i18n/ru-RU/workflow.json
@@ -305,6 +305,7 @@
   "error.operations.updatingWorkflow": "обновление рабочего процесса",
   "error.startNodeRequired": "Пожалуйста, сначала добавьте начальный узел перед {{operation}}",
   "errorMsg.authRequired": "Требуется авторизация",
+  "errorMsg.configureModel": "Настройте модель",
   "errorMsg.fieldRequired": "{{field}} обязательно для заполнения",
   "errorMsg.fields.code": "Код",
   "errorMsg.fields.model": "Модель",
@@ -314,6 +315,7 @@
   "errorMsg.fields.visionVariable": "Переменная зрения",
   "errorMsg.invalidJson": "{{field}} неверный JSON",
   "errorMsg.invalidVariable": "Неверная переменная",
+  "errorMsg.modelPluginNotInstalled": "Недопустимая переменная. Настройте модель, чтобы активировать эту переменную.",
   "errorMsg.noValidTool": "{{field}} не выбран валидный инструмент",
   "errorMsg.rerankModelRequired": "Перед включением модели повторного ранжирования убедитесь, что модель успешно настроена в настройках.",
   "errorMsg.startNodeRequired": "Пожалуйста, сначала добавьте начальный узел перед {{operation}}",
@@ -444,6 +446,7 @@
   "nodes.common.memory.windowSize": "Размер окна",
   "nodes.common.outputVars": "Выходные переменные",
   "nodes.common.pluginNotInstalled": "Плагин не установлен",
+  "nodes.common.pluginsNotInstalled": "{{count}} плагинов не установлено",
   "nodes.common.retry.maxRetries": "максимальное количество повторных попыток",
   "nodes.common.retry.ms": "госпожа",
   "nodes.common.retry.retries": "{{num}} Повторных попыток",
@@ -686,9 +689,14 @@
   "nodes.knowledgeBase.chunksInput": "Куски",
   "nodes.knowledgeBase.chunksInputTip": "Входная переменная узла базы знаний - это Чанки. Тип переменной является объектом с определенной схемой JSON, которая должна соответствовать выбранной структуре чанка.",
   "nodes.knowledgeBase.chunksVariableIsRequired": "Переменная chunks обязательна",
+  "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API-ключ недоступен",
+  "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Кредиты исчерпаны",
+  "nodes.knowledgeBase.embeddingModelIncompatible": "Несовместимо",
   "nodes.knowledgeBase.embeddingModelIsInvalid": "Модель встраивания недействительна",
   "nodes.knowledgeBase.embeddingModelIsRequired": "Требуется модель встраивания",
+  "nodes.knowledgeBase.embeddingModelNotConfigured": "Модель эмбеддингов не настроена",
   "nodes.knowledgeBase.indexMethodIsRequired": "Метод index является обязательным",
+  "nodes.knowledgeBase.notConfigured": "Не настроено",
   "nodes.knowledgeBase.rerankingModelIsInvalid": "Модель повторной ранжировки недействительна",
   "nodes.knowledgeBase.rerankingModelIsRequired": "Требуется модель перераспределения рангов",
   "nodes.knowledgeBase.retrievalSettingIsRequired": "Настройка извлечения обязательна",
@@ -872,6 +880,7 @@
   "nodes.templateTransform.codeSupportTip": "Поддерживает только Jinja2",
   "nodes.templateTransform.inputVars": "Входные переменные",
   "nodes.templateTransform.outputVars.output": "Преобразованный контент",
+  "nodes.tool.authorizationRequired": "Требуется авторизация",
   "nodes.tool.authorize": "Авторизовать",
   "nodes.tool.inputVars": "Входные переменные",
   "nodes.tool.insertPlaceholder1": "Наберите или нажмите",
@@ -1062,10 +1071,12 @@
   "panel.change": "Изменить",
   "panel.changeBlock": "Изменить узел",
   "panel.checklist": "Контрольный список",
+  "panel.checklistDescription": "Устраните следующие проблемы перед публикацией",
   "panel.checklistResolved": "Все проблемы решены",
   "panel.checklistTip": "Убедитесь, что все проблемы решены перед публикацией",
   "panel.createdBy": "Создано ",
   "panel.goTo": "Перейти к",
+  "panel.goToFix": "Перейти к исправлению",
   "panel.helpLink": "Помощь",
   "panel.maximize": "Максимизировать холст",
   "panel.minimize": "Выйти из полноэкранного режима",
diff --git a/web/i18n/sl-SI/app-debug.json b/web/i18n/sl-SI/app-debug.json
index 948e3dbc67..cf3d6c9ec4 100644
--- a/web/i18n/sl-SI/app-debug.json
+++ b/web/i18n/sl-SI/app-debug.json
@@ -235,11 +235,16 @@
   "inputs.run": "TEČI",
   "inputs.title": "Odpravljanje napak in predogled",
   "inputs.userInputField": "Uporabniško polje za vnos",
+  "manageModels": "Upravljanje modelov",
   "modelConfig.modeType.chat": "Chat",
   "modelConfig.modeType.completion": "Dokončati",
   "modelConfig.model": "Model",
   "modelConfig.setTone": "Nastavitev tona odzivov",
   "modelConfig.title": "Model in parametri",
+  "noModelProviderConfigured": "Ponudnik modelov ni konfiguriran",
+  "noModelProviderConfiguredTip": "Za začetek namestite ali konfigurirajte ponudnika modelov.",
+  "noModelSelected": "Noben model ni izbran",
+  "noModelSelectedTip": "Za nadaljevanje konfigurirajte model zgoraj.",
   "noResult": "Tukaj bo prikazan izhod.",
   "notSetAPIKey.description": "Ključ ponudnika LLM ni nastavljen. Pred odpravljanjem napak je treba nastaviti ključ.",
   "notSetAPIKey.settingBtn": "Pojdi v nastavitve",
diff --git a/web/i18n/sl-SI/common.json b/web/i18n/sl-SI/common.json
index 247874bb30..6ec4fe430c 100644
--- a/web/i18n/sl-SI/common.json
+++ b/web/i18n/sl-SI/common.json
@@ -340,11 +340,25 @@
   "modelProvider.auth.unAuthorized": "Neavtorizirano",
   "modelProvider.buyQuota": "Kupi kvoto",
   "modelProvider.callTimes": "Število klicev",
+  "modelProvider.card.aiCreditsInUse": "Krediti AI v uporabi",
+  "modelProvider.card.aiCreditsOption": "Krediti AI",
+  "modelProvider.card.apiKeyOption": "API ključ",
+  "modelProvider.card.apiKeyRequired": "Potreben je API ključ",
+  "modelProvider.card.apiKeyUnavailableFallback": "API ključ ni na voljo, zdaj se uporabljajo krediti AI",
+  "modelProvider.card.apiKeyUnavailableFallbackDescription": "Preverite konfiguracijo API ključa za preklop nazaj",
   "modelProvider.card.buyQuota": "Kupi kvoto",
   "modelProvider.card.callTimes": "Časi klicev",
+  "modelProvider.card.creditsExhaustedDescription": "Prosimo, nadgradite načrt ali konfigurirajte API ključ",
+  "modelProvider.card.creditsExhaustedFallback": "Krediti AI so porabljeni, zdaj se uporablja API ključ",
+  "modelProvider.card.creditsExhaustedFallbackDescription": "Nadgradite načrt za nadaljevanje prednostne uporabe kreditov AI.",
+  "modelProvider.card.creditsExhaustedMessage": "Krediti AI so bili porabljeni",
   "modelProvider.card.modelAPI": "Modeli {{modelName}} uporabljajo API ključ.",
   "modelProvider.card.modelNotSupported": "Modeli {{modelName}} niso nameščeni.",
   "modelProvider.card.modelSupported": "Modeli {{modelName}} uporabljajo to kvoto.",
+  "modelProvider.card.noApiKeysDescription": "Dodajte API ključ za začetek uporabe lastnih poverilnic modela.",
+  "modelProvider.card.noApiKeysFallback": "Ni API ključev, namesto tega se uporabljajo krediti AI",
+  "modelProvider.card.noApiKeysTitle": "Še ni konfiguriranih API ključev",
+  "modelProvider.card.noAvailableUsage": "Ni razpoložljive uporabe",
   "modelProvider.card.onTrial": "Na preizkusu",
   "modelProvider.card.paid": "Plačano",
   "modelProvider.card.priorityUse": "Prednostna uporaba",
@@ -353,6 +367,11 @@
   "modelProvider.card.removeKey": "Odstrani API ključ",
   "modelProvider.card.tip": "Krediti za sporočila podpirajo modele od {{modelNames}}. Prednostno se bo uporabila plačana kvota. Brezplačna kvota se bo uporabila, ko bo plačana kvota porabljena.",
   "modelProvider.card.tokens": "Žetoni",
+  "modelProvider.card.unavailable": "Ni na voljo",
+  "modelProvider.card.upgradePlan": "nadgradite načrt",
+  "modelProvider.card.usageLabel": "Poraba",
+  "modelProvider.card.usagePriority": "Prednost porabe",
+  "modelProvider.card.usagePriorityTip": "Nastavite, kateri vir se bo najprej uporabil pri zaganjanju modelov.",
   "modelProvider.collapse": "Strni",
   "modelProvider.config": "Konfiguracija",
   "modelProvider.configLoadBalancing": "Konfiguracija uravnoteženja obremenitev",
@@ -387,9 +406,11 @@
   "modelProvider.model": "Model",
   "modelProvider.modelAndParameters": "Model in parametri",
   "modelProvider.modelHasBeenDeprecated": "Ta model je zastarel",
+  "modelProvider.modelSettings": "Nastavitve modela",
   "modelProvider.models": "Modeli",
   "modelProvider.modelsNum": "{{num}} modelov",
   "modelProvider.noModelFound": "Za {{model}} ni najden noben model",
+  "modelProvider.noneConfigured": "Konfigurirajte privzeti sistemski model za zaganjanje aplikacij",
   "modelProvider.notConfigured": "Sistemski model še ni popolnoma konfiguriran, nekatere funkcije morda ne bodo na voljo.",
   "modelProvider.parameters": "PARAMETRI",
   "modelProvider.parametersInvalidRemoved": "Nekateri parametri so neveljavni in so bili odstranjeni.",
@@ -403,8 +424,25 @@
   "modelProvider.resetDate": "Ponastavi {{date}}",
   "modelProvider.searchModel": "Model iskanja",
   "modelProvider.selectModel": "Izberite svoj model",
+  "modelProvider.selector.aiCredits": "Krediti AI",
+  "modelProvider.selector.apiKeyUnavailable": "API ključ ni na voljo",
+  "modelProvider.selector.apiKeyUnavailableTip": "API ključ je bil odstranjen. Prosimo, konfigurirajte nov API ključ.",
+  "modelProvider.selector.configure": "Konfiguriraj",
+  "modelProvider.selector.configureRequired": "Konfiguracija je obvezna",
+  "modelProvider.selector.creditsExhausted": "Krediti so porabljeni",
+  "modelProvider.selector.creditsExhaustedTip": "Vaši krediti AI so bili porabljeni. Prosimo, nadgradite načrt ali dodajte API ključ.",
+  "modelProvider.selector.disabled": "Onemogočeno",
+  "modelProvider.selector.discoverMoreInMarketplace": "Odkrijte več v Tržnici",
   "modelProvider.selector.emptySetting": "Prosimo, pojdite v nastavitve za konfiguracijo",
   "modelProvider.selector.emptyTip": "Ni razpoložljivih modelov",
+  "modelProvider.selector.fromMarketplace": "Iz Tržnice",
+  "modelProvider.selector.incompatible": "Nezdružljivo",
+  "modelProvider.selector.incompatibleTip": "Ta model ni na voljo v trenutni različici. Izberite drug razpoložljiv model.",
+  "modelProvider.selector.install": "Namesti",
+  "modelProvider.selector.modelProviderSettings": "Nastavitve ponudnika modelov",
+  "modelProvider.selector.noProviderConfigured": "Noben ponudnik modelov ni konfiguriran",
+  "modelProvider.selector.noProviderConfiguredDesc": "Poiščite v Tržnici za namestitev ali konfigurirajte ponudnike v nastavitvah.",
+  "modelProvider.selector.onlyCompatibleModelsShown": "Prikazani so samo združljivi modeli",
   "modelProvider.selector.rerankTip": "Prosimo, nastavite model za prerazvrstitev",
   "modelProvider.selector.tip": "Ta model je bil odstranjen. Prosimo, dodajte model ali izberite drugega.",
   "modelProvider.setupModelFirst": "Najprej nastavite svoj model",
diff --git a/web/i18n/sl-SI/plugin.json b/web/i18n/sl-SI/plugin.json
index 0f84e91c80..91f6c670bf 100644
--- a/web/i18n/sl-SI/plugin.json
+++ b/web/i18n/sl-SI/plugin.json
@@ -3,6 +3,7 @@
   "action.delete": "Odstrani vtičnik",
   "action.deleteContentLeft": "Ali želite odstraniti",
   "action.deleteContentRight": "vtičnik?",
+  "action.deleteSuccess": "Vtičnik je bil uspešno odstranjen",
   "action.pluginInfo": "Informacije o vtičniku",
   "action.usedInApps": "Ta vtičnik se uporablja v {{num}} aplikacijah.",
   "allCategories": "Vse kategorije",
@@ -114,6 +115,7 @@
   "detailPanel.operation.install": "Namestite",
   "detailPanel.operation.remove": "Odstrani",
   "detailPanel.operation.update": "Posodobitev",
+  "detailPanel.operation.updateTooltip": "Posodobite za dostop do najnovejših modelov.",
   "detailPanel.operation.viewDetail": "Oglej si podrobnosti",
   "detailPanel.serviceOk": "Storitve so v redu",
   "detailPanel.strategyNum": "{{num}} {{strategy}} VKLJUČENO",
@@ -231,12 +233,18 @@
   "source.local": "Lokalna paketna datoteka",
   "source.marketplace": "Tržnica",
   "task.clearAll": "Počisti vse",
+  "task.errorMsg.github": "Ta vtičnik ni bil samodejno nameščen.\nProsimo, namestite ga iz GitHuba.",
+  "task.errorMsg.marketplace": "Ta vtičnik ni bil samodejno nameščen.\nProsimo, namestite ga iz Tržnice.",
+  "task.errorMsg.unknown": "Ta vtičnik ni bil nameščen.\nIzvora vtičnika ni mogoče ugotoviti.",
   "task.errorPlugins": "Failed to Install Plugins",
   "task.installError": "{{errorLength}} vtičnikov ni uspelo namestiti, kliknite za ogled",
+  "task.installFromGithub": "Namestite iz GitHuba",
+  "task.installFromMarketplace": "Namestite iz Tržnice",
   "task.installSuccess": "{{successLength}} plugins installed successfully",
   "task.installed": "Installed",
   "task.installedError": "{{errorLength}} vtičnikov ni uspelo namestiti",
   "task.installing": "Nameščanje vtičnikov.",
+  "task.installingHint": "Nameščanje... To lahko traja nekaj minut.",
   "task.installingWithError": "Namestitev {{installingLength}} vtičnikov, {{successLength}} uspešnih, {{errorLength}} neuspešnih",
   "task.installingWithSuccess": "Namestitev {{installingLength}} dodatkov, {{successLength}} uspešnih.",
   "task.runningPlugins": "Installing Plugins",
diff --git a/web/i18n/sl-SI/workflow.json b/web/i18n/sl-SI/workflow.json
index f11175d1d4..a20a30753d 100644
--- a/web/i18n/sl-SI/workflow.json
+++ b/web/i18n/sl-SI/workflow.json
@@ -305,6 +305,7 @@
   "error.operations.updatingWorkflow": "posodabljanje delovnega procesa",
   "error.startNodeRequired": "Prosimo, najprej dodajte začetni vozel pred {{operation}}",
   "errorMsg.authRequired": "Zahtevana je avtorizacija",
+  "errorMsg.configureModel": "Konfiguriraj model",
   "errorMsg.fieldRequired": "{{field}} je obvezno",
   "errorMsg.fields.code": "Koda",
   "errorMsg.fields.model": "Model",
@@ -314,6 +315,7 @@
   "errorMsg.fields.visionVariable": "Vizijska spremenljivka",
   "errorMsg.invalidJson": "{{field}} je neveljaven JSON",
   "errorMsg.invalidVariable": "Neveljavna spremenljivka",
+  "errorMsg.modelPluginNotInstalled": "Neveljavna spremenljivka. Konfigurirajte model za omogočanje te spremenljivke.",
   "errorMsg.noValidTool": "{{field}} ni izbranega veljavnega orodja",
   "errorMsg.rerankModelRequired": "Zahteva se konfigurirana model ponovnega razvrščanja.",
   "errorMsg.startNodeRequired": "Prosimo, najprej dodajte začetni vozel pred {{operation}}",
@@ -444,6 +446,7 @@
   "nodes.common.memory.windowSize": "Velikost okna",
   "nodes.common.outputVars": "Izhodne spremenljivke",
   "nodes.common.pluginNotInstalled": "Vtičnik ni nameščen",
+  "nodes.common.pluginsNotInstalled": "{{count}} vtičnikov ni nameščenih",
   "nodes.common.retry.maxRetries": "maksimalno število poskusov",
   "nodes.common.retry.ms": "ms",
   "nodes.common.retry.retries": "{{num}} Poskusi",
@@ -686,9 +689,14 @@
   "nodes.knowledgeBase.chunksInput": "Kosi",
   "nodes.knowledgeBase.chunksInputTip": "Vhodna spremenljivka vozlišča podatkovne baze je Chunks. Tip spremenljivke je objekt s specifično JSON shemo, ki mora biti skladna z izbrano strukturo kosov.",
   "nodes.knowledgeBase.chunksVariableIsRequired": "Spremenljivka Chunks je obvezna",
+  "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API ključ ni na voljo",
+  "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Krediti so porabljeni",
+  "nodes.knowledgeBase.embeddingModelIncompatible": "Nezdružljivo",
   "nodes.knowledgeBase.embeddingModelIsInvalid": "Vdelovalni model ni veljaven",
   "nodes.knowledgeBase.embeddingModelIsRequired": "Zahteva se vgrajevalni model",
+  "nodes.knowledgeBase.embeddingModelNotConfigured": "Vdelovalni model ni konfiguriran",
   "nodes.knowledgeBase.indexMethodIsRequired": "Zahteva se indeksna metoda",
+  "nodes.knowledgeBase.notConfigured": "Ni konfigurirano",
   "nodes.knowledgeBase.rerankingModelIsInvalid": "Model prerazvrščanja ni veljaven",
   "nodes.knowledgeBase.rerankingModelIsRequired": "Potreben je model za ponovno razvrščanje",
   "nodes.knowledgeBase.retrievalSettingIsRequired": "Zahtevana je nastavitev pridobivanja",
@@ -872,6 +880,7 @@
   "nodes.templateTransform.codeSupportTip": "Podpira samo Jinja2",
   "nodes.templateTransform.inputVars": "Vhodne spremenljivke",
   "nodes.templateTransform.outputVars.output": "Transformirana vsebina",
+  "nodes.tool.authorizationRequired": "Zahtevana je avtorizacija",
   "nodes.tool.authorize": "Pooblasti",
   "nodes.tool.inputVars": "Vhodne spremenljivke",
   "nodes.tool.insertPlaceholder1": "Vnesite ali pritisnite",
@@ -1062,10 +1071,12 @@
   "panel.change": "Spremeni",
   "panel.changeBlock": "Spremeni vozlišče",
   "panel.checklist": "Kontrolni seznam",
+  "panel.checklistDescription": "Rešite naslednje težave pred objavo",
   "panel.checklistResolved": "Vse težave so rešene",
   "panel.checklistTip": "Prepričajte se, da so vse težave rešene, preden objavite.",
   "panel.createdBy": "Ustvarjeno z",
   "panel.goTo": "Pojdi na",
+  "panel.goToFix": "Pojdi na popravi",
   "panel.helpLink": "Pomoč",
   "panel.maximize": "Maksimiziraj platno",
   "panel.minimize": "Izhod iz celotnega zaslona",
diff --git a/web/i18n/th-TH/app-debug.json b/web/i18n/th-TH/app-debug.json
index 53dc629a7f..c36db9ce18 100644
--- a/web/i18n/th-TH/app-debug.json
+++ b/web/i18n/th-TH/app-debug.json
@@ -235,11 +235,16 @@
   "inputs.run": "วิ่ง",
   "inputs.title": "ดีบัก & ดูตัวอย่าง",
   "inputs.userInputField": "ฟิลด์ป้อนข้อมูลของผู้ใช้",
+  "manageModels": "จัดการโมเดล",
   "modelConfig.modeType.chat": "สนทนา",
   "modelConfig.modeType.completion": "สมบูรณ์",
   "modelConfig.model": "แบบ",
   "modelConfig.setTone": "กําหนดน้ําเสียงของการตอบกลับ",
   "modelConfig.title": "รุ่นและพารามิเตอร์",
+  "noModelProviderConfigured": "ยังไม่ได้กำหนดค่าผู้ให้บริการโมเดล",
+  "noModelProviderConfiguredTip": "ติดตั้งหรือกำหนดค่าผู้ให้บริการโมเดลเพื่อเริ่มต้นใช้งาน",
+  "noModelSelected": "ยังไม่ได้เลือกโมเดล",
+  "noModelSelectedTip": "กำหนดค่าโมเดลด้านบนเพื่อดำเนินการต่อ",
   "noResult": "ผลลัพธ์จะแสดงที่นี่",
   "notSetAPIKey.description": "ยังไม่ได้ตั้งค่าคีย์ผู้ให้บริการ LLM และจําเป็นต้องตั้งค่าก่อนการดีบัก",
   "notSetAPIKey.settingBtn": "ไปที่การตั้งค่า",
diff --git a/web/i18n/th-TH/common.json b/web/i18n/th-TH/common.json
index 71c9369599..6eed5eba93 100644
--- a/web/i18n/th-TH/common.json
+++ b/web/i18n/th-TH/common.json
@@ -340,11 +340,25 @@
   "modelProvider.auth.unAuthorized": "ไม่ได้รับอนุญาต",
   "modelProvider.buyQuota": "ซื้อโควต้า",
   "modelProvider.callTimes": "เวลาโทร",
+  "modelProvider.card.aiCreditsInUse": "กำลังใช้เครดิต AI",
+  "modelProvider.card.aiCreditsOption": "เครดิต AI",
+  "modelProvider.card.apiKeyOption": "API Key",
+  "modelProvider.card.apiKeyRequired": "ต้องการ API Key",
+  "modelProvider.card.apiKeyUnavailableFallback": "API Key ไม่พร้อมใช้งาน กำลังใช้เครดิต AI แทน",
+  "modelProvider.card.apiKeyUnavailableFallbackDescription": "ตรวจสอบการกำหนดค่า API Key ของคุณเพื่อสลับกลับ",
   "modelProvider.card.buyQuota": "ซื้อโควต้า",
   "modelProvider.card.callTimes": "เวลาโทร",
+  "modelProvider.card.creditsExhaustedDescription": "กรุณาอัปเกรดแพ็กเกจหรือกำหนดค่า API Key",
+  "modelProvider.card.creditsExhaustedFallback": "เครดิต AI หมดแล้ว กำลังใช้ API Key แทน",
+  "modelProvider.card.creditsExhaustedFallbackDescription": "อัปเกรดแพ็กเกจเพื่อกลับมาใช้เครดิต AI เป็นลำดับแรก",
+  "modelProvider.card.creditsExhaustedMessage": "เครดิต AI หมดแล้ว",
   "modelProvider.card.modelAPI": "โมเดล {{modelName}} กำลังใช้คีย์ API",
   "modelProvider.card.modelNotSupported": "โมเดล {{modelName}} ไม่ได้ติดตั้ง",
   "modelProvider.card.modelSupported": "โมเดล {{modelName}} กำลังใช้โควต้านี้",
+  "modelProvider.card.noApiKeysDescription": "เพิ่ม API Key เพื่อเริ่มใช้ข้อมูลรับรองโมเดลของคุณเอง",
+  "modelProvider.card.noApiKeysFallback": "ไม่มี API Key กำลังใช้เครดิต AI แทน",
+  "modelProvider.card.noApiKeysTitle": "ยังไม่ได้กำหนดค่า API Key",
+  "modelProvider.card.noAvailableUsage": "ไม่มีปริมาณการใช้งานที่พร้อมใช้",
   "modelProvider.card.onTrial": "ทดลองใช้",
   "modelProvider.card.paid": "จ่าย",
   "modelProvider.card.priorityUse": "ลําดับความสําคัญในการใช้งาน",
@@ -353,6 +367,11 @@
   "modelProvider.card.removeKey": "ลบคีย์ API",
   "modelProvider.card.tip": "เครดิตข้อความรองรับโมเดลจาก {{modelNames}} จะให้ลำดับความสำคัญกับโควต้าที่ชำระแล้ว โควต้าฟรีจะถูกใช้หลังจากโควต้าที่ชำระแล้วหมด",
   "modelProvider.card.tokens": "โท เค็น",
+  "modelProvider.card.unavailable": "ไม่พร้อมใช้งาน",
+  "modelProvider.card.upgradePlan": "อัปเกรดแพ็กเกจ",
+  "modelProvider.card.usageLabel": "การใช้งาน",
+  "modelProvider.card.usagePriority": "ลำดับความสำคัญการใช้งาน",
+  "modelProvider.card.usagePriorityTip": "ตั้งค่าทรัพยากรที่จะใช้ก่อนเมื่อเรียกใช้โมเดล",
   "modelProvider.collapse": "ทรุด",
   "modelProvider.config": "กําหนดค่า",
   "modelProvider.configLoadBalancing": "กําหนดค่าโหลดบาลานซ์",
@@ -387,9 +406,11 @@
   "modelProvider.model": "แบบ",
   "modelProvider.modelAndParameters": "รุ่นและพารามิเตอร์",
   "modelProvider.modelHasBeenDeprecated": "โมเดลนี้เลิกใช้แล้ว",
+  "modelProvider.modelSettings": "การตั้งค่าโมเดล",
   "modelProvider.models": "รุ่น",
   "modelProvider.modelsNum": "{{num}} รุ่น",
   "modelProvider.noModelFound": "ไม่พบแบบจําลองสําหรับ {{model}}",
+  "modelProvider.noneConfigured": "กำหนดค่าโมเดลระบบเริ่มต้นเพื่อเรียกใช้แอปพลิเคชัน",
   "modelProvider.notConfigured": "โมเดลระบบยังไม่ได้รับการกําหนดค่าอย่างสมบูรณ์ และฟังก์ชันบางอย่างอาจไม่พร้อมใช้งาน",
   "modelProvider.parameters": "พารามิเตอร์",
   "modelProvider.parametersInvalidRemoved": "บางพารามิเตอร์ไม่ถูกต้องและถูกนำออก",
@@ -403,8 +424,25 @@
   "modelProvider.resetDate": "รีเซ็ตเมื่อ {{date}}",
   "modelProvider.searchModel": "ค้นหารุ่น",
   "modelProvider.selectModel": "เลือกรุ่นของคุณ",
+  "modelProvider.selector.aiCredits": "เครดิต AI",
+  "modelProvider.selector.apiKeyUnavailable": "API Key ไม่พร้อมใช้งาน",
+  "modelProvider.selector.apiKeyUnavailableTip": "API Key ถูกลบแล้ว กรุณากำหนดค่า API Key ใหม่",
+  "modelProvider.selector.configure": "กำหนดค่า",
+  "modelProvider.selector.configureRequired": "ต้องกำหนดค่า",
+  "modelProvider.selector.creditsExhausted": "เครดิตหมดแล้ว",
+  "modelProvider.selector.creditsExhaustedTip": "เครดิต AI ของคุณหมดแล้ว กรุณาอัปเกรดแพ็กเกจหรือเพิ่ม API Key",
+  "modelProvider.selector.disabled": "ปิดใช้งาน",
+  "modelProvider.selector.discoverMoreInMarketplace": "ค้นพบเพิ่มเติมใน Marketplace",
   "modelProvider.selector.emptySetting": "โปรดไปที่การตั้งค่าเพื่อกําหนดค่า",
   "modelProvider.selector.emptyTip": "ไม่มีรุ่นที่พร้อมใช้งาน",
+  "modelProvider.selector.fromMarketplace": "จาก Marketplace",
+  "modelProvider.selector.incompatible": "ไม่รองรับ",
+  "modelProvider.selector.incompatibleTip": "โมเดลนี้ไม่พร้อมใช้งานในเวอร์ชันปัจจุบัน กรุณาเลือกโมเดลอื่นที่พร้อมใช้งาน",
+  "modelProvider.selector.install": "ติดตั้ง",
+  "modelProvider.selector.modelProviderSettings": "การตั้งค่าผู้ให้บริการโมเดล",
+  "modelProvider.selector.noProviderConfigured": "ยังไม่ได้กำหนดค่าผู้ให้บริการโมเดล",
+  "modelProvider.selector.noProviderConfiguredDesc": "เรียกดู Marketplace เพื่อติดตั้ง หรือกำหนดค่าผู้ให้บริการในการตั้งค่า",
+  "modelProvider.selector.onlyCompatibleModelsShown": "แสดงเฉพาะโมเดลที่รองรับเท่านั้น",
   "modelProvider.selector.rerankTip": "โปรดตั้งค่าโมเดล Rerank",
   "modelProvider.selector.tip": "รุ่นนี้ถูกลบออกแล้ว โปรดเพิ่มรุ่นหรือเลือกรุ่นอื่น",
   "modelProvider.setupModelFirst": "โปรดตั้งค่าโมเดลของคุณก่อน",
diff --git a/web/i18n/th-TH/plugin.json b/web/i18n/th-TH/plugin.json
index 1c5a544abc..0545d0d46b 100644
--- a/web/i18n/th-TH/plugin.json
+++ b/web/i18n/th-TH/plugin.json
@@ -3,6 +3,7 @@
   "action.delete": "ลบปลั๊กอิน",
   "action.deleteContentLeft": "คุณต้องการลบ",
   "action.deleteContentRight": "ปลั๊กอิน?",
+  "action.deleteSuccess": "ลบปลั๊กอินสำเร็จแล้ว",
   "action.pluginInfo": "ข้อมูลปลั๊กอิน",
   "action.usedInApps": "ปลั๊กอินนี้ถูกใช้ในแอป {{num}}",
   "allCategories": "หมวดหมู่ทั้งหมด",
@@ -114,6 +115,7 @@
   "detailPanel.operation.install": "ติดตั้ง",
   "detailPanel.operation.remove": "ถอด",
   "detailPanel.operation.update": "อัพเดต",
+  "detailPanel.operation.updateTooltip": "อัปเดตเพื่อเข้าถึงโมเดลล่าสุด",
   "detailPanel.operation.viewDetail": "ดูรายละเอียด",
   "detailPanel.serviceOk": "บริการตกลง",
   "detailPanel.strategyNum": "{{num}} {{strategy}} รวม",
@@ -231,12 +233,18 @@
   "source.local": "ไฟล์แพ็คเกจในเครื่อง",
   "source.marketplace": "ตลาด",
   "task.clearAll": "ล้างทั้งหมด",
+  "task.errorMsg.github": "ไม่สามารถติดตั้งปลั๊กอินนี้โดยอัตโนมัติได้\nกรุณาติดตั้งจาก GitHub",
+  "task.errorMsg.marketplace": "ไม่สามารถติดตั้งปลั๊กอินนี้โดยอัตโนมัติได้\nกรุณาติดตั้งจาก Marketplace",
+  "task.errorMsg.unknown": "ไม่สามารถติดตั้งปลั๊กอินนี้ได้\nไม่สามารถระบุแหล่งที่มาของปลั๊กอินได้",
   "task.errorPlugins": "Failed to Install Plugins",
   "task.installError": "{{errorLength}} ปลั๊กอินติดตั้งไม่สําเร็จ คลิกเพื่อดู",
+  "task.installFromGithub": "ติดตั้งจาก GitHub",
+  "task.installFromMarketplace": "ติดตั้งจาก Marketplace",
   "task.installSuccess": "{{successLength}} plugins installed successfully",
   "task.installed": "Installed",
   "task.installedError": "{{errorLength}} ปลั๊กอินติดตั้งไม่สําเร็จ",
   "task.installing": "การติดตั้งปลั๊กอิน",
+  "task.installingHint": "กำลังติดตั้ง... อาจใช้เวลาสักครู่",
   "task.installingWithError": "การติดตั้งปลั๊กอิน {{installingLength}}, {{successLength}} สําเร็จ, {{errorLength}} ล้มเหลว",
   "task.installingWithSuccess": "การติดตั้งปลั๊กอิน {{installingLength}}, {{successLength}} สําเร็จ",
   "task.runningPlugins": "Installing Plugins",
diff --git a/web/i18n/th-TH/workflow.json b/web/i18n/th-TH/workflow.json
index 84bfe5ed97..7819e884c3 100644
--- a/web/i18n/th-TH/workflow.json
+++ b/web/i18n/th-TH/workflow.json
@@ -305,6 +305,7 @@
   "error.operations.updatingWorkflow": "กำลังปรับปรุงเวิร์กโฟลว์",
   "error.startNodeRequired": "โปรดเพิ่มโหนดเริ่มต้นก่อน {{operation}}",
   "errorMsg.authRequired": "ต้องได้รับอนุญาต",
+  "errorMsg.configureModel": "กำหนดค่าโมเดล",
   "errorMsg.fieldRequired": "{{field}} เป็นสิ่งจําเป็น",
   "errorMsg.fields.code": "รหัส",
   "errorMsg.fields.model": "แบบ",
@@ -314,6 +315,7 @@
   "errorMsg.fields.visionVariable": "ตัวแปรวิสัยทัศน์",
   "errorMsg.invalidJson": "{{field}} เป็น JSON ไม่ถูกต้อง",
   "errorMsg.invalidVariable": "ตัวแปรไม่ถูกต้อง",
+  "errorMsg.modelPluginNotInstalled": "ตัวแปรไม่ถูกต้อง กำหนดค่าโมเดลเพื่อเปิดใช้งานตัวแปรนี้",
   "errorMsg.noValidTool": "{{field}} ไม่ได้เลือกเครื่องมือที่ถูกต้อง",
   "errorMsg.rerankModelRequired": "ก่อนเปิด Rerank Model โปรดยืนยันว่าได้กําหนดค่าโมเดลสําเร็จในการตั้งค่า",
   "errorMsg.startNodeRequired": "โปรดเพิ่มโหนดเริ่มต้นก่อน {{operation}}",
@@ -444,6 +446,7 @@
   "nodes.common.memory.windowSize": "ขนาดหน้าต่าง",
   "nodes.common.outputVars": "ตัวแปรเอาต์พุต",
   "nodes.common.pluginNotInstalled": "ปลั๊กอินไม่ได้ติดตั้ง",
+  "nodes.common.pluginsNotInstalled": "{{count}} ปลั๊กอินยังไม่ได้ติดตั้ง",
   "nodes.common.retry.maxRetries": "การลองซ้ําสูงสุด",
   "nodes.common.retry.ms": "นางสาว",
   "nodes.common.retry.retries": "{{num}} ลอง",
@@ -686,9 +689,14 @@
   "nodes.knowledgeBase.chunksInput": "ชิ้นส่วน",
   "nodes.knowledgeBase.chunksInputTip": "ตัวแปรนำเข้าของโหนดฐานความรู้คือ Chunks ตัวแปรประเภทเป็นอ็อบเจ็กต์ที่มี JSON Schema เฉพาะซึ่งต้องสอดคล้องกับโครงสร้างชิ้นส่วนที่เลือกไว้.",
   "nodes.knowledgeBase.chunksVariableIsRequired": "ตัวแปร Chunks เป็นสิ่งจำเป็น",
+  "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key ไม่พร้อมใช้งาน",
+  "nodes.knowledgeBase.embeddingModelCreditsExhausted": "เครดิตหมดแล้ว",
+  "nodes.knowledgeBase.embeddingModelIncompatible": "ไม่รองรับ",
   "nodes.knowledgeBase.embeddingModelIsInvalid": "แบบจำลองการฝังไม่ถูกต้อง",
   "nodes.knowledgeBase.embeddingModelIsRequired": "จำเป็นต้องใช้โมเดลฝัง",
+  "nodes.knowledgeBase.embeddingModelNotConfigured": "ยังไม่ได้กำหนดค่าโมเดล Embedding",
   "nodes.knowledgeBase.indexMethodIsRequired": "ต้องใช้วิธีการจัดทําดัชนี",
+  "nodes.knowledgeBase.notConfigured": "ยังไม่ได้กำหนดค่า",
   "nodes.knowledgeBase.rerankingModelIsInvalid": "โมเดลการจัดอันดับใหม่ไม่ถูกต้อง",
   "nodes.knowledgeBase.rerankingModelIsRequired": "จำเป็นต้องมีโมเดลการจัดอันดับใหม่",
   "nodes.knowledgeBase.retrievalSettingIsRequired": "จําเป็นต้องมีการตั้งค่าการดึงข้อมูล",
@@ -872,6 +880,7 @@
   "nodes.templateTransform.codeSupportTip": "รองรับเฉพาะ Jinja2",
   "nodes.templateTransform.inputVars": "ตัวแปรอินพุต",
   "nodes.templateTransform.outputVars.output": "เนื้อหาที่แปลงโฉม",
+  "nodes.tool.authorizationRequired": "ต้องการการอนุญาต",
   "nodes.tool.authorize": "อนุญาต",
   "nodes.tool.inputVars": "ตัวแปรอินพุต",
   "nodes.tool.insertPlaceholder1": "พิมพ์หรือลงทะเบียน",
@@ -1062,10 +1071,12 @@
   "panel.change": "เปลี่ยน",
"panel.changeBlock": "เปลี่ยนโหนด", "panel.checklist": "ตรวจ สอบ", + "panel.checklistDescription": "กรุณาแก้ไขปัญหาต่อไปนี้ก่อนเผยแพร่", "panel.checklistResolved": "ปัญหาทั้งหมดได้รับการแก้ไขแล้ว", "panel.checklistTip": "ตรวจสอบให้แน่ใจว่าปัญหาทั้งหมดได้รับการแก้ไขแล้วก่อนที่จะเผยแพร่", "panel.createdBy": "สร้างโดย", "panel.goTo": "ไปที่", + "panel.goToFix": "ไปแก้ไข", "panel.helpLink": "วิธีใช้", "panel.maximize": "เพิ่มประสิทธิภาพผ้าใบ", "panel.minimize": "ออกจากโหมดเต็มหน้าจอ", diff --git a/web/i18n/uk-UA/app-debug.json b/web/i18n/uk-UA/app-debug.json index 74dcc72efd..dd7ec2de8f 100644 --- a/web/i18n/uk-UA/app-debug.json +++ b/web/i18n/uk-UA/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "ЗАПУСТИТИ", "inputs.title": "Налагодження та попередній перегляд", "inputs.userInputField": "Поле введення користувача", + "manageModels": "Керування моделями", "modelConfig.modeType.chat": "Чат", "modelConfig.modeType.completion": "Завершення", "modelConfig.model": "Модель", "modelConfig.setTone": "Встановити тон відповідей", "modelConfig.title": "Модель і параметри", + "noModelProviderConfigured": "Постачальник моделей не налаштований", + "noModelProviderConfiguredTip": "Встановіть або налаштуйте постачальника моделей, щоб почати роботу.", + "noModelSelected": "Модель не вибрана", + "noModelSelectedTip": "налаштуйте модель вище, щоб продовжити.", "noResult": "Тут буде відображено вихідні дані.", "notSetAPIKey.description": "Ключ провайдера LLM не встановлено, і його потрібно встановити перед налагодженням.", "notSetAPIKey.settingBtn": "Перейти до налаштувань", diff --git a/web/i18n/uk-UA/common.json b/web/i18n/uk-UA/common.json index 6745644573..806cbede3d 100644 --- a/web/i18n/uk-UA/common.json +++ b/web/i18n/uk-UA/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Несанкціоновано", "modelProvider.buyQuota": "Придбати квоту", "modelProvider.callTimes": "Кількість викликів", + "modelProvider.card.aiCreditsInUse": "Використовуються AI-кредити", + "modelProvider.card.aiCreditsOption": "AI-кредити", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Потрібен API-ключ", + "modelProvider.card.apiKeyUnavailableFallback": "API Key недоступний, використовуються AI-кредити", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Перевірте конфігурацію API-ключа, щоб повернутися до нього", "modelProvider.card.buyQuota": "Придбати квоту", "modelProvider.card.callTimes": "Кількість викликів", + "modelProvider.card.creditsExhaustedDescription": "Будь ласка, оновіть свій план або налаштуйте API-ключ", + "modelProvider.card.creditsExhaustedFallback": "AI-кредити вичерпано, використовується API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Оновіть свій план, щоб відновити пріоритет AI-кредитів.", + "modelProvider.card.creditsExhaustedMessage": "AI-кредити вичерпано", "modelProvider.card.modelAPI": "Моделі {{modelName}} використовують API-ключ.", "modelProvider.card.modelNotSupported": "Моделі {{modelName}} не встановлено.", "modelProvider.card.modelSupported": "Моделі {{modelName}} використовують цю квоту.", + "modelProvider.card.noApiKeysDescription": "Додайте API-ключ, щоб почати використовувати власні облікові дані моделі.", + "modelProvider.card.noApiKeysFallback": "API-ключі відсутні, використовуються AI-кредити", + "modelProvider.card.noApiKeysTitle": "API-ключі ще не налаштовані", + "modelProvider.card.noAvailableUsage": "Немає доступного використання", "modelProvider.card.onTrial": "У пробному періоді", 
"modelProvider.card.paid": "Оплачено", "modelProvider.card.priorityUse": "Пріоритетне використання", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Видалити ключ API", "modelProvider.card.tip": "Кредити повідомлень підтримують моделі від {{modelNames}}. Пріоритет буде надано оплаченій квоті. Безкоштовна квота буде використовуватися після вичерпання платної квоти.", "modelProvider.card.tokens": "Токени", + "modelProvider.card.unavailable": "Недоступно", + "modelProvider.card.upgradePlan": "оновіть свій план", + "modelProvider.card.usageLabel": "Використання", + "modelProvider.card.usagePriority": "Пріоритет використання", + "modelProvider.card.usagePriorityTip": "Встановіть, який ресурс використовувати першим при запуску моделей.", "modelProvider.collapse": "Згорнути", "modelProvider.config": "Налаштування", "modelProvider.configLoadBalancing": "Балансування навантаження конфігурації", @@ -387,9 +406,11 @@ "modelProvider.model": "Модель", "modelProvider.modelAndParameters": "Модель та параметри", "modelProvider.modelHasBeenDeprecated": "Ця модель вважається застарілою", + "modelProvider.modelSettings": "Налаштування моделі", "modelProvider.models": "Моделі", "modelProvider.modelsNum": "{{num}} моделей", "modelProvider.noModelFound": "Модель для {{model}} не знайдено", + "modelProvider.noneConfigured": "Налаштуйте системну модель за замовчуванням для запуску застосунків", "modelProvider.notConfigured": "Системну модель ще не повністю налаштовано, і деякі функції можуть бути недоступні.", "modelProvider.parameters": "ПАРАМЕТРИ", "modelProvider.parametersInvalidRemoved": "Деякі параметри є недійсними і були видалені", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Скидання {{date}}", "modelProvider.searchModel": "Пошукова модель", "modelProvider.selectModel": "Виберіть свою модель", + "modelProvider.selector.aiCredits": "AI-кредити", + "modelProvider.selector.apiKeyUnavailable": "API Key недоступний", + "modelProvider.selector.apiKeyUnavailableTip": "API-ключ було видалено. Будь ласка, налаштуйте новий API-ключ.", + "modelProvider.selector.configure": "Налаштувати", + "modelProvider.selector.configureRequired": "Потрібне налаштування", + "modelProvider.selector.creditsExhausted": "Кредити вичерпано", + "modelProvider.selector.creditsExhaustedTip": "Ваші AI-кредити вичерпано. Будь ласка, оновіть свій план або додайте API-ключ.", + "modelProvider.selector.disabled": "Вимкнено", + "modelProvider.selector.discoverMoreInMarketplace": "Знайти більше в Marketplace", "modelProvider.selector.emptySetting": "Перейдіть до налаштувань, щоб налаштувати", "modelProvider.selector.emptyTip": "Доступні моделі відсутні", + "modelProvider.selector.fromMarketplace": "З Marketplace", + "modelProvider.selector.incompatible": "Несумісно", + "modelProvider.selector.incompatibleTip": "Ця модель недоступна в поточній версії. Будь ласка, виберіть іншу доступну модель.", + "modelProvider.selector.install": "Встановити", + "modelProvider.selector.modelProviderSettings": "Налаштування постачальника моделей", + "modelProvider.selector.noProviderConfigured": "Постачальник моделей не налаштований", + "modelProvider.selector.noProviderConfiguredDesc": "Перегляньте Marketplace для встановлення або налаштуйте постачальників у параметрах.", + "modelProvider.selector.onlyCompatibleModelsShown": "Показано лише сумісні моделі", "modelProvider.selector.rerankTip": "Будь ласка, налаштуйте модель повторного ранжування", "modelProvider.selector.tip": "Цю модель було видалено. 
Будь ласка, додайте модель або виберіть іншу.", "modelProvider.setupModelFirst": "Будь ласка, спочатку налаштуйте свою модель", diff --git a/web/i18n/uk-UA/plugin.json b/web/i18n/uk-UA/plugin.json index 98bebf0ca8..603a8226a3 100644 --- a/web/i18n/uk-UA/plugin.json +++ b/web/i18n/uk-UA/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Видалити плагін", "action.deleteContentLeft": "Чи хотіли б ви видалити", "action.deleteContentRight": "плагін?", + "action.deleteSuccess": "Плагін успішно видалено", "action.pluginInfo": "Інформація про плагін", "action.usedInApps": "Цей плагін використовується в додатках {{num}}.", "allCategories": "Всі категорії", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Інсталювати", "detailPanel.operation.remove": "Видалити", "detailPanel.operation.update": "Оновлювати", + "detailPanel.operation.updateTooltip": "Оновіть, щоб отримати доступ до найновіших моделей.", "detailPanel.operation.viewDetail": "Переглянути деталі", "detailPanel.serviceOk": "Сервіс працює", "detailPanel.strategyNum": "{{num}} {{strategy}} ВКЛЮЧЕНІ", @@ -231,12 +233,18 @@ "source.local": "Файл локального пакета", "source.marketplace": "Ринку", "task.clearAll": "Очистити все", + "task.errorMsg.github": "Цей плагін не вдалося встановити автоматично.\nБудь ласка, встановіть його з GitHub.", + "task.errorMsg.marketplace": "Цей плагін не вдалося встановити автоматично.\nБудь ласка, встановіть його з Marketplace.", + "task.errorMsg.unknown": "Не вдалося встановити цей плагін.\nДжерело плагіна не вдалося визначити.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Плагіни {{errorLength}} не вдалося встановити, натисніть, щоб переглянути", + "task.installFromGithub": "Встановити з GitHub", + "task.installFromMarketplace": "Встановити з Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Плагіни {{errorLength}} не вдалося встановити", "task.installing": "Встановлення плагінів.", + "task.installingHint": "Встановлення... Це може зайняти кілька хвилин.", "task.installingWithError": "Не вдалося встановити плагіни {{installingLength}}, успіх {{successLength}}, {{errorLength}}", "task.installingWithSuccess": "Встановлення плагінів {{installingLength}}, успіх {{successLength}}.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/uk-UA/workflow.json b/web/i18n/uk-UA/workflow.json index e9365303f2..4fa95f6d57 100644 --- a/web/i18n/uk-UA/workflow.json +++ b/web/i18n/uk-UA/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "оновлення робочого процесу", "error.startNodeRequired": "Будь ласка, спершу додайте стартовий вузол перед {{operation}}", "errorMsg.authRequired": "Потрібна авторизація", + "errorMsg.configureModel": "Налаштуйте модель", "errorMsg.fieldRequired": "{{field}} є обов'язковим", "errorMsg.fields.code": "Код", "errorMsg.fields.model": "Модель", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Змінна зору", "errorMsg.invalidJson": "{{field}} є недійсним JSON", "errorMsg.invalidVariable": "Недійсна змінна", + "errorMsg.modelPluginNotInstalled": "Недійсна змінна. 
Налаштуйте модель, щоб увімкнути цю змінну.", "errorMsg.noValidTool": "{{field}} не вибрано дійсного інструменту", "errorMsg.rerankModelRequired": "Перед увімкненням Rerank Model, будь ласка, підтвердьте, що модель успішно налаштована в налаштуваннях.", "errorMsg.startNodeRequired": "Будь ласка, спершу додайте стартовий вузол перед {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Розмір вікна", "nodes.common.outputVars": "Змінні виходу", "nodes.common.pluginNotInstalled": "Плагін не встановлений", + "nodes.common.pluginsNotInstalled": "{{count}} плагінів не встановлено", "nodes.common.retry.maxRetries": "Максимальна кількість повторних спроб", "nodes.common.retry.ms": "МС", "nodes.common.retry.retries": "{{num}} Спроб", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Частини", "nodes.knowledgeBase.chunksInputTip": "Вхідна змінна вузла бази знань - це Частини. Тип змінної - об'єкт з певною JSON-схемою, яка повинна відповідати вибраній структурі частин.", "nodes.knowledgeBase.chunksVariableIsRequired": "Змінна chunks є обов'язковою", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API-ключ недоступний", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Кредити вичерпано", + "nodes.knowledgeBase.embeddingModelIncompatible": "Несумісно", "nodes.knowledgeBase.embeddingModelIsInvalid": "Модель вбудовування недійсна", "nodes.knowledgeBase.embeddingModelIsRequired": "Потрібна модель вбудовування", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Модель ембедингу не налаштована", "nodes.knowledgeBase.indexMethodIsRequired": "Обов'язковий індексний метод", + "nodes.knowledgeBase.notConfigured": "Не налаштовано", "nodes.knowledgeBase.rerankingModelIsInvalid": "Модель переналаштування недійсна", "nodes.knowledgeBase.rerankingModelIsRequired": "Потрібна модель повторного ранжування", "nodes.knowledgeBase.retrievalSettingIsRequired": "Потрібне налаштування для отримання", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Підтримує лише Jinja2", "nodes.templateTransform.inputVars": "Вхідні змінні", "nodes.templateTransform.outputVars.output": "Трансформований вміст", + "nodes.tool.authorizationRequired": "Потрібна авторизація", "nodes.tool.authorize": "Уповноважити", "nodes.tool.inputVars": "Вхідні змінні", "nodes.tool.insertPlaceholder1": "Введіть або натисніть", @@ -1062,10 +1071,12 @@ "panel.change": "Змінити", "panel.changeBlock": "Змінити вузол", "panel.checklist": "Контрольний список", + "panel.checklistDescription": "Вирішіть наступні проблеми перед публікацією", "panel.checklistResolved": "Всі проблеми вирішені", "panel.checklistTip": "Переконайтеся, що всі проблеми вирішені перед публікацією", "panel.createdBy": "Створено ", "panel.goTo": "Перейти до", + "panel.goToFix": "Перейти до виправлення", "panel.helpLink": "Довідковий центр", "panel.maximize": "Максимізувати полотно", "panel.minimize": "Вийти з повноекранного режиму", diff --git a/web/i18n/vi-VN/app-debug.json b/web/i18n/vi-VN/app-debug.json index d533a02370..92037650cb 100644 --- a/web/i18n/vi-VN/app-debug.json +++ b/web/i18n/vi-VN/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "CHẠY", "inputs.title": "Gỡ lỗi và xem trước", "inputs.userInputField": "Trường nhập liệu người dùng", + "manageModels": "Quản lý mô hình", "modelConfig.modeType.chat": "Trò chuyện", "modelConfig.modeType.completion": "Hoàn thành", "modelConfig.model": "Mô hình", "modelConfig.setTone": "Thiết lập giọng điệu của phản hồi", "modelConfig.title": "Mô hình và tham số", + "noModelProviderConfigured": 
"Chưa cấu hình nhà cung cấp mô hình", + "noModelProviderConfiguredTip": "Cài đặt hoặc cấu hình nhà cung cấp mô hình để bắt đầu.", + "noModelSelected": "Chưa chọn mô hình", + "noModelSelectedTip": "cấu hình mô hình ở trên để tiếp tục.", "noResult": "Đầu ra sẽ được hiển thị ở đây.", "notSetAPIKey.description": "Chưa thiết lập khóa API của nhà cung cấp LLM. Cần thiết lập trước khi gỡ lỗi.", "notSetAPIKey.settingBtn": "Đi đến cài đặt", diff --git a/web/i18n/vi-VN/common.json b/web/i18n/vi-VN/common.json index a47ee6da57..820bfdfdab 100644 --- a/web/i18n/vi-VN/common.json +++ b/web/i18n/vi-VN/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Không có quyền truy cập", "modelProvider.buyQuota": "Mua Quyền lợi", "modelProvider.callTimes": "Số lần gọi", + "modelProvider.card.aiCreditsInUse": "Đang sử dụng AI credits", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Yêu cầu API key", + "modelProvider.card.apiKeyUnavailableFallback": "API Key không khả dụng, đang sử dụng AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Kiểm tra cấu hình API key để chuyển lại", "modelProvider.card.buyQuota": "Mua Quota", "modelProvider.card.callTimes": "Số lần gọi", + "modelProvider.card.creditsExhaustedDescription": "Vui lòng nâng cấp gói dịch vụ hoặc cấu hình API key", + "modelProvider.card.creditsExhaustedFallback": "AI credits đã hết, đang sử dụng API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Nâng cấp gói dịch vụ để khôi phục ưu tiên AI credits.", + "modelProvider.card.creditsExhaustedMessage": "AI credits đã hết", "modelProvider.card.modelAPI": "Các mô hình {{modelName}} đang sử dụng Khóa API.", "modelProvider.card.modelNotSupported": "Các mô hình {{modelName}} chưa được cài đặt.", "modelProvider.card.modelSupported": "Các mô hình {{modelName}} đang sử dụng hạn mức này.", + "modelProvider.card.noApiKeysDescription": "Thêm API key để bắt đầu sử dụng thông tin xác thực mô hình của bạn.", + "modelProvider.card.noApiKeysFallback": "Không có API key, sử dụng AI credits thay thế", + "modelProvider.card.noApiKeysTitle": "Chưa cấu hình API key", + "modelProvider.card.noAvailableUsage": "Không có lượt sử dụng khả dụng", "modelProvider.card.onTrial": "Thử nghiệm", "modelProvider.card.paid": "Đã thanh toán", "modelProvider.card.priorityUse": "Ưu tiên sử dụng", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Remove API Key", "modelProvider.card.tip": "Tín dụng tin nhắn hỗ trợ các mô hình từ {{modelNames}}. Ưu tiên sẽ được trao cho hạn ngạch đã thanh toán. 
Hạn ngạch miễn phí sẽ được sử dụng sau khi hết hạn ngạch trả phí.", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "Không khả dụng", + "modelProvider.card.upgradePlan": "nâng cấp gói dịch vụ", + "modelProvider.card.usageLabel": "Sử dụng", + "modelProvider.card.usagePriority": "Ưu tiên sử dụng", + "modelProvider.card.usagePriorityTip": "Đặt tài nguyên sử dụng trước khi chạy mô hình.", "modelProvider.collapse": "Thu gọn", "modelProvider.config": "Cấu hình", "modelProvider.configLoadBalancing": "Cấu hình cân bằng tải", @@ -387,9 +406,11 @@ "modelProvider.model": "Mô hình", "modelProvider.modelAndParameters": "Mô hình và Tham số", "modelProvider.modelHasBeenDeprecated": "Mô hình này đã bị phản đối", + "modelProvider.modelSettings": "Cài đặt mô hình", "modelProvider.models": "Mô hình", "modelProvider.modelsNum": "{{num}} Mô hình", "modelProvider.noModelFound": "Không tìm thấy mô hình cho {{model}}", + "modelProvider.noneConfigured": "Cấu hình mô hình hệ thống mặc định để chạy ứng dụng", "modelProvider.notConfigured": "Mô hình hệ thống vẫn chưa được cấu hình hoàn toàn và một số chức năng có thể không khả dụng.", "modelProvider.parameters": "THAM SỐ", "modelProvider.parametersInvalidRemoved": "Một số tham số không hợp lệ và đã được loại bỏ", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Đặt lại vào {{date}}", "modelProvider.searchModel": "Mô hình tìm kiếm", "modelProvider.selectModel": "Chọn mô hình của bạn", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key không khả dụng", + "modelProvider.selector.apiKeyUnavailableTip": "API key đã bị xóa. Vui lòng cấu hình API key mới.", + "modelProvider.selector.configure": "Cấu hình", + "modelProvider.selector.configureRequired": "Cần cấu hình", + "modelProvider.selector.creditsExhausted": "Credits đã hết", + "modelProvider.selector.creditsExhaustedTip": "AI credits của bạn đã hết. Vui lòng nâng cấp gói dịch vụ hoặc thêm API key.", + "modelProvider.selector.disabled": "Đã tắt", + "modelProvider.selector.discoverMoreInMarketplace": "Khám phá thêm trên Marketplace", "modelProvider.selector.emptySetting": "Vui lòng vào cài đặt để cấu hình", "modelProvider.selector.emptyTip": "Không có mô hình khả dụng", + "modelProvider.selector.fromMarketplace": "Từ Marketplace", + "modelProvider.selector.incompatible": "Không tương thích", + "modelProvider.selector.incompatibleTip": "Mô hình này không khả dụng trong phiên bản hiện tại. Vui lòng chọn mô hình khả dụng khác.", + "modelProvider.selector.install": "Cài đặt", + "modelProvider.selector.modelProviderSettings": "Cài đặt nhà cung cấp mô hình", + "modelProvider.selector.noProviderConfigured": "Chưa cấu hình nhà cung cấp mô hình", + "modelProvider.selector.noProviderConfiguredDesc": "Duyệt Marketplace để cài đặt hoặc cấu hình nhà cung cấp trong phần cài đặt.", + "modelProvider.selector.onlyCompatibleModelsShown": "Chỉ hiển thị các mô hình tương thích", "modelProvider.selector.rerankTip": "Vui lòng thiết lập mô hình sắp xếp lại", "modelProvider.selector.tip": "Mô hình này đã bị xóa. 
Vui lòng thêm một mô hình hoặc chọn mô hình khác.", "modelProvider.setupModelFirst": "Vui lòng thiết lập mô hình của bạn trước", diff --git a/web/i18n/vi-VN/plugin.json b/web/i18n/vi-VN/plugin.json index 40cc20b7b0..1c1b37bba3 100644 --- a/web/i18n/vi-VN/plugin.json +++ b/web/i18n/vi-VN/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Xóa plugin", "action.deleteContentLeft": "Bạn có muốn xóa", "action.deleteContentRight": "plugin?", + "action.deleteSuccess": "Đã xóa plugin thành công", "action.pluginInfo": "Thông tin plugin", "action.usedInApps": "Plugin này đang được sử dụng trong các ứng dụng {{num}}.", "allCategories": "Tất cả các danh mục", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Cài đặt", "detailPanel.operation.remove": "Triệt", "detailPanel.operation.update": "Cập nhật", + "detailPanel.operation.updateTooltip": "Cập nhật để truy cập các mô hình mới nhất.", "detailPanel.operation.viewDetail": "xem chi tiết", "detailPanel.serviceOk": "Dịch vụ OK", "detailPanel.strategyNum": "{{num}} {{strategy}} BAO GỒM", @@ -231,12 +233,18 @@ "source.local": "Tệp gói cục bộ", "source.marketplace": "Chợ", "task.clearAll": "Xóa tất cả", + "task.errorMsg.github": "Không thể cài đặt plugin này tự động.\nVui lòng cài đặt từ GitHub.", + "task.errorMsg.marketplace": "Không thể cài đặt plugin này tự động.\nVui lòng cài đặt từ Marketplace.", + "task.errorMsg.unknown": "Không thể cài đặt plugin này.\nKhông xác định được nguồn plugin.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} plugin không cài đặt được, nhấp để xem", + "task.installFromGithub": "Cài đặt từ GitHub", + "task.installFromMarketplace": "Cài đặt từ Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} plugin không cài đặt được", "task.installing": "Đang cài đặt plugin.", + "task.installingHint": "Đang cài đặt... Quá trình này có thể mất vài phút.", "task.installingWithError": "Cài đặt {{installingLength}} plugins, {{successLength}} thành công, {{errorLength}} không thành công", "task.installingWithSuccess": "Cài đặt {{installingLength}} plugins, {{successLength}} thành công.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/vi-VN/workflow.json b/web/i18n/vi-VN/workflow.json index 3bdd86e231..3a8bdbaaf1 100644 --- a/web/i18n/vi-VN/workflow.json +++ b/web/i18n/vi-VN/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "cập nhật quy trình công việc", "error.startNodeRequired": "Vui lòng thêm một nút bắt đầu trước {{operation}}", "errorMsg.authRequired": "Yêu cầu xác thực", + "errorMsg.configureModel": "Cấu hình mô hình", "errorMsg.fieldRequired": "{{field}} là bắt buộc", "errorMsg.fields.code": "Mã", "errorMsg.fields.model": "Mô hình", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Biến tầm nhìn", "errorMsg.invalidJson": "{{field}} là JSON không hợp lệ", "errorMsg.invalidVariable": "Biến không hợp lệ", + "errorMsg.modelPluginNotInstalled": "Biến không hợp lệ. 
Cấu hình mô hình để kích hoạt biến này.", "errorMsg.noValidTool": "{{field}} không chọn công cụ hợp lệ nào", "errorMsg.rerankModelRequired": "Trước khi bật Mô hình xếp hạng lại, vui lòng xác nhận rằng mô hình đã được định cấu hình thành công trong cài đặt.", "errorMsg.startNodeRequired": "Vui lòng thêm một nút bắt đầu trước {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Kích thước cửa sổ", "nodes.common.outputVars": "Biến đầu ra", "nodes.common.pluginNotInstalled": "Plugin chưa được cài đặt", + "nodes.common.pluginsNotInstalled": "{{count}} plugin chưa được cài đặt", "nodes.common.retry.maxRetries": "Số lần thử lại tối đa", "nodes.common.retry.ms": "Ms", "nodes.common.retry.retries": "{{num}} Thử lại", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Mảnh", "nodes.knowledgeBase.chunksInputTip": "Biến đầu vào của nút cơ sở tri thức là Chunks. Loại biến là một đối tượng với một JSON Schema cụ thể mà phải nhất quán với cấu trúc chunk đã chọn.", "nodes.knowledgeBase.chunksVariableIsRequired": "Biến Chunks là bắt buộc", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API key không khả dụng", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Credits đã hết", + "nodes.knowledgeBase.embeddingModelIncompatible": "Không tương thích", "nodes.knowledgeBase.embeddingModelIsInvalid": "Mô hình nhúng không hợp lệ", "nodes.knowledgeBase.embeddingModelIsRequired": "Cần có mô hình nhúng", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Chưa cấu hình mô hình nhúng", "nodes.knowledgeBase.indexMethodIsRequired": "Phương pháp chỉ mục là bắt buộc", + "nodes.knowledgeBase.notConfigured": "Chưa cấu hình", "nodes.knowledgeBase.rerankingModelIsInvalid": "Mô hình xếp hạng lại không hợp lệ", "nodes.knowledgeBase.rerankingModelIsRequired": "Cần có mô hình sắp xếp lại", "nodes.knowledgeBase.retrievalSettingIsRequired": "Cài đặt truy xuất là bắt buộc", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Chỉ hỗ trợ Jinja2", "nodes.templateTransform.inputVars": "Biến đầu vào", "nodes.templateTransform.outputVars.output": "Nội dung chuyển đổi", + "nodes.tool.authorizationRequired": "Yêu cầu ủy quyền", "nodes.tool.authorize": "Ủy quyền", "nodes.tool.inputVars": "Biến đầu vào", "nodes.tool.insertPlaceholder1": "Gõ hoặc nhấn", @@ -1062,10 +1071,12 @@ "panel.change": "Thay đổi", "panel.changeBlock": "Thay đổi Node", "panel.checklist": "Danh sách kiểm tra", + "panel.checklistDescription": "Giải quyết các vấn đề sau trước khi xuất bản", "panel.checklistResolved": "Tất cả các vấn đề đã được giải quyết", "panel.checklistTip": "Đảm bảo rằng tất cả các vấn đề đã được giải quyết trước khi xuất bản", "panel.createdBy": "Tạo bởi ", "panel.goTo": "Đi tới", + "panel.goToFix": "Đi đến sửa lỗi", "panel.helpLink": "Trung tâm trợ giúp", "panel.maximize": "Tối đa hóa Canvas", "panel.minimize": "Thoát chế độ toàn màn hình", diff --git a/web/i18n/zh-Hant/app-debug.json b/web/i18n/zh-Hant/app-debug.json index ab7286691f..3058cd8458 100644 --- a/web/i18n/zh-Hant/app-debug.json +++ b/web/i18n/zh-Hant/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "執行", "inputs.title": "除錯與預覽", "inputs.userInputField": "使用者輸入", + "manageModels": "管理模型", "modelConfig.modeType.chat": "對話型", "modelConfig.modeType.completion": "補全型", "modelConfig.model": "語言模型", "modelConfig.setTone": "模型設定", "modelConfig.title": "模型及引數", + "noModelProviderConfigured": "未配置模型供應商", + "noModelProviderConfiguredTip": "請先安裝或配置模型供應商以開始使用。", + "noModelSelected": "未選擇模型", + "noModelSelectedTip": "請先在上方配置模型以繼續。", "noResult": 
"輸出將顯示在此處。", "notSetAPIKey.description": "在除錯之前需要設定 LLM 提供者的金鑰。", "notSetAPIKey.settingBtn": "去設定", diff --git a/web/i18n/zh-Hant/common.json b/web/i18n/zh-Hant/common.json index 85ff9cc687..9317f68f82 100644 --- a/web/i18n/zh-Hant/common.json +++ b/web/i18n/zh-Hant/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "未經授權", "modelProvider.buyQuota": "購買額度", "modelProvider.callTimes": "呼叫次數", + "modelProvider.card.aiCreditsInUse": "AI 額度使用中", + "modelProvider.card.aiCreditsOption": "AI 額度", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "需要配置 API Key", + "modelProvider.card.apiKeyUnavailableFallback": "API Key 不可用,正在使用 AI 額度", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "檢查你的 API Key 配置以切換回來", "modelProvider.card.buyQuota": "購買額度", "modelProvider.card.callTimes": "呼叫次數", + "modelProvider.card.creditsExhaustedDescription": "請升級方案或配置 API Key", + "modelProvider.card.creditsExhaustedFallback": "AI 額度已用盡,正在使用 API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "升級方案以恢復 AI 額度優先使用。", + "modelProvider.card.creditsExhaustedMessage": "AI 額度已用盡", "modelProvider.card.modelAPI": "{{modelName}} 模型正在使用 API Key。", "modelProvider.card.modelNotSupported": "{{modelName}} 模型未安裝。", "modelProvider.card.modelSupported": "{{modelName}} 模型正在使用此配額。", + "modelProvider.card.noApiKeysDescription": "新增 API Key 以使用自有模型憑證。", + "modelProvider.card.noApiKeysFallback": "未配置 API Key,正在使用 AI 額度", + "modelProvider.card.noApiKeysTitle": "尚未配置 API Key", + "modelProvider.card.noAvailableUsage": "無可用額度", "modelProvider.card.onTrial": "試用中", "modelProvider.card.paid": "已購買", "modelProvider.card.priorityUse": "優先使用", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "刪除 API 金鑰", "modelProvider.card.tip": "消息額度支持使用 {{modelNames}} 的模型;免費額度會在付費額度用盡後才會消耗。", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "不可用", + "modelProvider.card.upgradePlan": "升級方案", + "modelProvider.card.usageLabel": "用量", + "modelProvider.card.usagePriority": "使用優先順序", + "modelProvider.card.usagePriorityTip": "設定執行模型時優先使用的資源。", "modelProvider.collapse": "收起", "modelProvider.config": "配置", "modelProvider.configLoadBalancing": "配置負載均衡", @@ -387,9 +406,11 @@ "modelProvider.model": "模型", "modelProvider.modelAndParameters": "模型及引數", "modelProvider.modelHasBeenDeprecated": "此模型已棄用", + "modelProvider.modelSettings": "模型設定", "modelProvider.models": "模型列表", "modelProvider.modelsNum": "{{num}} 個模型", "modelProvider.noModelFound": "找不到模型 {{model}}", + "modelProvider.noneConfigured": "配置預設系統模型以執行應用", "modelProvider.notConfigured": "系統模型尚未完全配置,部分功能可能無法使用。", "modelProvider.parameters": "引數", "modelProvider.parametersInvalidRemoved": "一些參數無效,已被移除", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "於 {{date}} 重置", "modelProvider.searchModel": "搜尋模型", "modelProvider.selectModel": "選擇您的模型", + "modelProvider.selector.aiCredits": "AI 額度", + "modelProvider.selector.apiKeyUnavailable": "API Key 不可用", + "modelProvider.selector.apiKeyUnavailableTip": "API Key 已被移除,請重新配置 API Key。", + "modelProvider.selector.configure": "配置", + "modelProvider.selector.configureRequired": "需要配置", + "modelProvider.selector.creditsExhausted": "額度已用盡", + "modelProvider.selector.creditsExhaustedTip": "AI 額度已用盡,請升級方案或新增 API Key。", + "modelProvider.selector.disabled": "已停用", + "modelProvider.selector.discoverMoreInMarketplace": "在插件市場探索更多", "modelProvider.selector.emptySetting": "請前往設定進行配置", "modelProvider.selector.emptyTip": "無可用模型", + "modelProvider.selector.fromMarketplace": "從插件市場安裝", + 
"modelProvider.selector.incompatible": "不相容", + "modelProvider.selector.incompatibleTip": "此模型在目前版本中不可用,請選擇其他可用模型。", + "modelProvider.selector.install": "安裝", + "modelProvider.selector.modelProviderSettings": "模型供應商設定", + "modelProvider.selector.noProviderConfigured": "未配置模型供應商", + "modelProvider.selector.noProviderConfiguredDesc": "前往插件市場安裝,或在設定中配置供應商。", + "modelProvider.selector.onlyCompatibleModelsShown": "僅顯示相容的模型", "modelProvider.selector.rerankTip": "請設定 Rerank 模型", "modelProvider.selector.tip": "該模型已被刪除。請添模型或選擇其他模型。", "modelProvider.setupModelFirst": "請先設定您的模型", diff --git a/web/i18n/zh-Hant/plugin.json b/web/i18n/zh-Hant/plugin.json index 20630f41ab..4e7d577aaa 100644 --- a/web/i18n/zh-Hant/plugin.json +++ b/web/i18n/zh-Hant/plugin.json @@ -3,6 +3,7 @@ "action.delete": "刪除插件", "action.deleteContentLeft": "是否要刪除", "action.deleteContentRight": "插件?", + "action.deleteSuccess": "插件移除成功", "action.pluginInfo": "插件資訊", "action.usedInApps": "此插件正在 {{num}} 個應用程式中使用。", "allCategories": "全部分類", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "安裝", "detailPanel.operation.remove": "刪除", "detailPanel.operation.update": "更新", + "detailPanel.operation.updateTooltip": "更新以取得最新模型。", "detailPanel.operation.viewDetail": "查看詳情", "detailPanel.serviceOk": "服務正常", "detailPanel.strategyNum": "{{num}} {{strategy}} 包括", @@ -231,12 +233,18 @@ "source.local": "本地包檔", "source.marketplace": "市場", "task.clearAll": "全部清除", + "task.errorMsg.github": "此插件無法自動安裝。\n請從 GitHub 安裝。", + "task.errorMsg.marketplace": "此插件無法自動安裝。\n請從插件市場安裝。", + "task.errorMsg.unknown": "此插件無法安裝。\n無法識別插件來源。", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} 個插件安裝失敗,點擊查看", + "task.installFromGithub": "從 GitHub 安裝", + "task.installFromMarketplace": "從插件市場安裝", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} 個插件安裝失敗", "task.installing": "正在安裝插件。", + "task.installingHint": "正在安裝……可能需要幾分鐘。", "task.installingWithError": "安裝 {{installingLength}} 個插件,{{successLength}} 成功,{{errorLength}} 失敗", "task.installingWithSuccess": "安裝 {{installingLength}} 個插件,{{successLength}} 成功。", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/zh-Hant/workflow.json b/web/i18n/zh-Hant/workflow.json index 8a3c703f55..b739984977 100644 --- a/web/i18n/zh-Hant/workflow.json +++ b/web/i18n/zh-Hant/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "更新工作流程", "error.startNodeRequired": "請先新增一個起始節點,再執行 {{operation}}", "errorMsg.authRequired": "請先授權", + "errorMsg.configureModel": "請配置模型", "errorMsg.fieldRequired": "{{field}} 不能為空", "errorMsg.fields.code": "程式碼", "errorMsg.fields.model": "模型", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vision Variable", "errorMsg.invalidJson": "{{field}} 是非法的 JSON", "errorMsg.invalidVariable": "無效的變數", + "errorMsg.modelPluginNotInstalled": "無效的變數。請配置模型以啟用此變數。", "errorMsg.noValidTool": "{{field}} 未選擇有效工具", "errorMsg.rerankModelRequired": "在開啟 Rerank 模型之前,請在設置中確認模型配置成功。", "errorMsg.startNodeRequired": "請先新增一個起始節點,再執行 {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "記憶窗口", "nodes.common.outputVars": "輸出變數", "nodes.common.pluginNotInstalled": "插件未安裝", + "nodes.common.pluginsNotInstalled": "{{count}} 個插件未安裝", "nodes.common.retry.maxRetries": "最大重試次數", "nodes.common.retry.ms": "毫秒", "nodes.common.retry.retries": "{{num}}重試", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "區塊", "nodes.knowledgeBase.chunksInputTip": "知識庫節點的輸入變數是 
Chunks。該變數類型是一個物件,具有特定的 JSON Schema,必須與所選的塊結構一致。", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks 變數是必需的", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key 不可用", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "額度已用盡", + "nodes.knowledgeBase.embeddingModelIncompatible": "不相容", "nodes.knowledgeBase.embeddingModelIsInvalid": "嵌入模型無效", "nodes.knowledgeBase.embeddingModelIsRequired": "需要嵌入模型", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Embedding 模型未配置", "nodes.knowledgeBase.indexMethodIsRequired": "索引方法是必填的", + "nodes.knowledgeBase.notConfigured": "未配置", "nodes.knowledgeBase.rerankingModelIsInvalid": "重排序模型無效", "nodes.knowledgeBase.rerankingModelIsRequired": "需要重新排序模型", "nodes.knowledgeBase.retrievalSettingIsRequired": "需要檢索設定", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "只支持 Jinja2", "nodes.templateTransform.inputVars": "輸入變數", "nodes.templateTransform.outputVars.output": "轉換後內容", + "nodes.tool.authorizationRequired": "需要授權", "nodes.tool.authorize": "授權", "nodes.tool.inputVars": "輸入變數", "nodes.tool.insertPlaceholder1": "輸入或按壓", @@ -1062,10 +1071,12 @@ "panel.change": "更改", "panel.changeBlock": "更改節點", "panel.checklist": "檢查清單", + "panel.checklistDescription": "發佈前請解決以下問題", "panel.checklistResolved": "所有問題均已解決", "panel.checklistTip": "發佈前確保所有問題均已解決", "panel.createdBy": "作者", "panel.goTo": "前往", + "panel.goToFix": "前往修復", "panel.helpLink": "查看幫助文件", "panel.maximize": "最大化畫布", "panel.minimize": "退出全螢幕", diff --git a/web/package.json b/web/package.json index 361f8e2e0e..65372ef5f5 100644 --- a/web/package.json +++ b/web/package.json @@ -220,11 +220,10 @@ "eslint-plugin-react-refresh": "0.5.2", "eslint-plugin-sonarjs": "4.0.2", "eslint-plugin-storybook": "10.3.1", + "happy-dom": "20.8.9", "hono": "4.12.8", "husky": "9.1.7", "iconify-import-svg": "0.1.2", - "jsdom": "29.0.1", - "jsdom-testing-mocks": "1.16.0", "knip": "6.0.2", "lint-staged": "16.4.0", "postcss": "8.5.8", diff --git a/web/plugins/dev-proxy/server.spec.ts b/web/plugins/dev-proxy/server.spec.ts index 9c950abae0..c57ec8b4fe 100644 --- a/web/plugins/dev-proxy/server.spec.ts +++ b/web/plugins/dev-proxy/server.spec.ts @@ -1,3 +1,6 @@ +/** + * @vitest-environment node + */ import { beforeEach, describe, expect, it, vi } from 'vitest' import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets } from './server' diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 191826d80d..cd1a8a8556 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -371,7 +371,7 @@ importers: devDependencies: '@antfu/eslint-config': specifier: 7.7.3 - version: 7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(typescript@5.9.3) 
+ version: 7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(typescript@5.9.3) '@chromatic-com/storybook': specifier: 5.0.2 version: 5.0.2(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) @@ -506,7 +506,7 @@ importers: version: 0.5.21(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) '@vitest/coverage-v8': specifier: 4.1.0 - version: 4.1.0(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) + version: 4.1.0(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) agentation: specifier: 2.3.3 version: 2.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -546,6 +546,9 @@ importers: eslint-plugin-storybook: specifier: 10.3.1 version: 10.3.1(eslint@10.1.0(jiti@1.21.7))(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + happy-dom: + specifier: 20.8.9 + version: 20.8.9 hono: specifier: 4.12.8 version: 4.12.8 @@ -555,12 +558,6 @@ importers: iconify-import-svg: specifier: 0.1.2 version: 0.1.2 - jsdom: - specifier: 29.0.1 - version: 29.0.1(canvas@3.2.2) - jsdom-testing-mocks: - specifier: 1.16.0 - version: 1.16.0 knip: specifier: 6.0.2 version: 6.0.2 @@ -608,13 +605,13 @@ importers: version: 11.3.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) vite-plus: specifier: 0.1.13 - version: 0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + version: 
0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) vitest: specifier: npm:@voidzero-dev/vite-plus-test@0.1.13 - version: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + version: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' vitest-canvas-mock: specifier: 1.1.3 - version: 1.1.3(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) + version: 1.1.3(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) packages: @@ -771,12 +768,12 @@ packages: '@antfu/utils@8.1.1': resolution: {integrity: sha512-Mex9nXf9vR6AhcXmMrlz/HVgYYZpVGJ6YlPgwl7UnaFpnshXs6EK/oa5Gpf3CzENMjkvEx2tQtntGnb7UtSTOQ==} - '@asamuzakjp/css-color@5.0.1': - resolution: {integrity: sha512-2SZFvqMyvboVV1d15lMf7XiI3m7SDqXUuKaTymJYLN6dSGadqp+fVojqJlVoMlbZnlTmu3S0TLwLTJpvBMO1Aw==} + '@asamuzakjp/css-color@5.1.1': + resolution: {integrity: sha512-iGWN8E45Ws0XWx3D44Q1t6vX2LqhCKcwfmwBYCDsFrYFS6m4q/Ks61L2veETaLv+ckDC6+dTETJoaAAb7VjLiw==} engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} - '@asamuzakjp/dom-selector@7.0.3': - resolution: {integrity: sha512-Q6mU0Z6bfj6YvnX2k9n0JxiIwrCFN59x/nWmYQnAqP000ruX/yV+5bp/GRcF5T8ncvfwJQ7fgfP74DlpKExILA==} + '@asamuzakjp/dom-selector@7.0.4': + resolution: {integrity: sha512-jXR6x4AcT3eIrS2fSNAwJpwirOkGcd+E7F7CP3zjdTqz9B/2huHOL8YJZBgekKwLML+u7qB/6P1LXQuMScsx0w==} engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} '@asamuzakjp/nwsapi@2.3.9': @@ -960,8 +957,8 @@ packages: peerDependencies: '@csstools/css-tokenizer': ^4.0.0 - '@csstools/css-syntax-patches-for-csstree@1.1.1': - resolution: {integrity: sha512-BvqN0AMWNAnLk9G8jnUT77D+mUbY/H2b3uDTvg2isJkHaOufUE2R3AOwxWo7VBQKT1lOdwdvorddo2B/lk64+w==} + '@csstools/css-syntax-patches-for-csstree@1.1.2': + resolution: {integrity: sha512-5GkLzz4prTIpoyeUiIu3iV6CSG3Plo7xRVOFPKI7FVEJ3mZ0A8SwK0XU3Gl7xAkiQ+mDyam+NNp875/C5y+jSA==} peerDependencies: css-tree: ^3.2.1 peerDependenciesMeta: @@ -3510,6 +3507,12 @@ packages: '@types/unist@3.0.3': resolution: {integrity: sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==} + '@types/whatwg-mimetype@3.0.2': + resolution: {integrity: 
sha512-c2AKvDT8ToxLIOUlN51gTiHXflsfIFisS4pO7pDPoKouJCESkhZnEy623gwP9laCy5lnLDAw1vAzu2vM2YLOrA==} + + '@types/ws@8.18.1': + resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + '@types/yauzl@2.10.3': resolution: {integrity: sha512-oJoftv0LSuaDZE3Le4DbKX+KS9G36NzOeSap90UIK0yMA/NhKJhqlSGtNDORNRaIbQfzjXDrQa0ytJ6mNRGz/Q==} @@ -4082,9 +4085,6 @@ packages: engines: {node: '>=6.0.0'} hasBin: true - bezier-easing@2.1.0: - resolution: {integrity: sha512-gbIqZ/eslnUFC1tjEvtz0sgx+xTK20wDnYMIA27VA04R7w6xxXQPZDbibjA9DTWZRA2CXtwHykkVzlCaAJAZig==} - bidi-js@1.0.3: resolution: {integrity: sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==} @@ -4385,9 +4385,6 @@ packages: resolution: {integrity: sha512-3O5QdqgFRUbXvK1x5INf1YkBz1UKSWqrd63vWsum8MNHDBYD5urm3QtxZbKU259OrEXNM26lP/MPY3d1IGkBgA==} engines: {node: '>=16'} - css-mediaquery@0.1.2: - resolution: {integrity: sha512-COtn4EROW5dBGlE/4PiKnh6rZpAPxDeFLaEEwt4i10jpDMFt2EhQGS79QmmrO+iKCHv0PU/HrOWEhijFd1x99Q==} - css-select@5.2.2: resolution: {integrity: sha512-TizTzUddG/xYLA3NXodFM0fSbNizXjOKhqiQQwvhlspadZokn1KDy0NZFS0wuEubIYAV5/c1/lAr0TaaFXEXzw==} @@ -5359,6 +5356,10 @@ packages: hachure-fill@0.5.2: resolution: {integrity: sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg==} + happy-dom@20.8.9: + resolution: {integrity: sha512-Tz23LR9T9jOGVZm2x1EPdXqwA37G/owYMxRwU0E4miurAtFsPMQ1d2Jc2okUaSjZqAFz2oEn3FLXC5a0a+siyA==} + engines: {node: '>=20.0.0'} + has-flag@4.0.0: resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} engines: {node: '>=8'} @@ -5690,10 +5691,6 @@ packages: resolution: {integrity: sha512-/2uqY7x6bsrpi3i9LVU6J89352C0rpMk0as8trXxCtvd4kPk1ke/Eyif6wqfSLvoNJqcDG9Vk4UsXgygzCt2xA==} engines: {node: '>=20.0.0'} - jsdom-testing-mocks@1.16.0: - resolution: {integrity: sha512-wLrulXiLpjmcUYOYGEvz4XARkrmdVpyxzdBl9IAMbQ+ib2/UhUTRCn49McdNfXLff2ysGBUms49ZKX0LR1Q0gg==} - engines: {node: '>=14'} - jsdom@29.0.1: resolution: {integrity: sha512-z6JOK5gRO7aMybVq/y/MlIpKh8JIi68FBKMUtKkK2KH/wMSRlCxQ682d08LB9fYXplyY/UXG8P4XXTScmdjApg==} engines: {node: ^20.19.0 || ^22.13.0 || >=24.0.0} @@ -7316,6 +7313,10 @@ packages: resolution: {integrity: sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==} engines: {node: '>=6'} + tapable@2.3.2: + resolution: {integrity: sha512-1MOpMXuhGzGL5TTCZFItxCc0AARf1EZFQkGqMm7ERKj8+Hgr5oLvJOVFcC+lRmR8hCe2S3jC4T5D7Vg/d7/fhA==} + engines: {node: '>=6'} + tar-fs@2.1.4: resolution: {integrity: sha512-mDAjwmZdh7LTT6pNleZ05Yt65HC3E+NiQzl672vQG38jIrehtJk/J3mNwIg+vShQPcLF/LV7CMnDW6vjj6sfYQ==} @@ -7535,8 +7536,8 @@ packages: resolution: {integrity: sha512-jxytwMHhsbdpBXxLAcuu0fzlQeXCNnWdDyRHpvWsUl8vd98UwYdl9YTyn8/HcpcJPC3pwUveefsa3zTxyD/ERg==} engines: {node: '>=20.18.1'} - undici@7.24.5: - resolution: {integrity: sha512-3IWdCpjgxp15CbJnsi/Y9TCDE7HWVN19j1hmzVhoAkY/+CJx449tVxT5wZc1Gwg8J+P0LWvzlBzxYRnHJ+1i7Q==} + undici@7.24.6: + resolution: {integrity: sha512-Xi4agocCbRzt0yYMZGMA6ApD7gvtUFaxm4ZmeacWI4cZxaF6C+8I8QfofC20NAePiB/IcvZmzkJ7XPa471AEtA==} engines: {node: '>=20.18.1'} unicode-trie@2.0.0: @@ -7683,7 +7684,7 @@ packages: resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} vinext@https://pkg.pr.new/vinext@b6a2cac: - resolution: {integrity: 
sha512-/Jm507qqC1dCOhCaorb9H8/I5JEqkcsiUJw0Wgprg7Znym4eyLUvcWcRLVyM9z22Tm0+O1PugcSDA8oNvbqPuQ==, tarball: https://pkg.pr.new/vinext@b6a2cac} + resolution: {tarball: https://pkg.pr.new/vinext@b6a2cac} version: 0.0.5 engines: {node: '>=22'} hasBin: true @@ -7841,6 +7842,10 @@ packages: engines: {node: '>=18'} deprecated: Use @exodus/bytes instead for a more spec-conformant and faster implementation + whatwg-mimetype@3.0.0: + resolution: {integrity: sha512-nt+N2dzIutVRxARx1nghPKGv1xHikU7HKdfafKkLNLindmPU/ch3U31NOCGGA/dmPcmb1VlofO0vnKAcsm0o/Q==} + engines: {node: '>=12'} + whatwg-mimetype@4.0.0: resolution: {integrity: sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==} engines: {node: '>=18'} @@ -7881,6 +7886,18 @@ packages: utf-8-validate: optional: true + ws@8.20.0: + resolution: {integrity: sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + wsl-utils@0.1.0: resolution: {integrity: sha512-h3Fbisa2nKGPxCpm89Hk33lBLsnaGBvctQopaBSOW/uIs6FTe1ATyAnKFJrzVs9vpGdsTe73WF3V4lIsk4Gacw==} engines: {node: '>=18'} @@ -8140,7 +8157,7 @@ snapshots: idb: 8.0.0 tslib: 2.8.1 - '@antfu/eslint-config@7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(typescript@5.9.3)': + '@antfu/eslint-config@7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(typescript@5.9.3)': dependencies: '@antfu/install-pkg': 1.1.0 '@clack/prompts': 1.1.0 @@ -8150,7 +8167,7 @@ snapshots: '@stylistic/eslint-plugin': 5.10.0(eslint@10.1.0(jiti@1.21.7)) '@typescript-eslint/eslint-plugin': 
8.57.1(@typescript-eslint/parser@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) '@typescript-eslint/parser': 8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) - '@vitest/eslint-plugin': 1.6.12(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@vitest/eslint-plugin': 1.6.12(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) ansis: 4.2.0 cac: 7.0.0 eslint: 10.1.0(jiti@1.21.7) @@ -8210,23 +8227,26 @@ snapshots: '@antfu/utils@8.1.1': {} - '@asamuzakjp/css-color@5.0.1': + '@asamuzakjp/css-color@5.1.1': dependencies: '@csstools/css-calc': 3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) '@csstools/css-color-parser': 4.0.2(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) '@csstools/css-tokenizer': 4.0.0 lru-cache: 11.2.7 + optional: true - '@asamuzakjp/dom-selector@7.0.3': + '@asamuzakjp/dom-selector@7.0.4': dependencies: '@asamuzakjp/nwsapi': 2.3.9 bidi-js: 1.0.3 css-tree: 3.2.1 is-potential-custom-element-name: 1.0.1 lru-cache: 11.2.7 + optional: true - '@asamuzakjp/nwsapi@2.3.9': {} + '@asamuzakjp/nwsapi@2.3.9': + optional: true '@babel/code-frame@7.29.0': dependencies: @@ -8361,6 +8381,7 @@ snapshots: '@bramus/specificity@2.4.2': dependencies: css-tree: 3.2.1 + optional: true '@chevrotain/cst-dts-gen@11.1.2': dependencies: @@ -8453,12 +8474,14 @@ snapshots: transitivePeerDependencies: - supports-color - '@csstools/color-helpers@6.0.2': {} + '@csstools/color-helpers@6.0.2': + optional: true '@csstools/css-calc@3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0)': dependencies: '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) '@csstools/css-tokenizer': 4.0.0 + optional: true '@csstools/css-color-parser@4.0.2(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0)': dependencies: @@ -8466,16 +8489,20 @@ snapshots: '@csstools/css-calc': 3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) '@csstools/css-tokenizer': 4.0.0 + optional: true '@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0)': dependencies: '@csstools/css-tokenizer': 4.0.0 + optional: true - '@csstools/css-syntax-patches-for-csstree@1.1.1(css-tree@3.2.1)': + '@csstools/css-syntax-patches-for-csstree@1.1.2(css-tree@3.2.1)': optionalDependencies: css-tree: 3.2.1 + optional: true - '@csstools/css-tokenizer@4.0.0': {} + '@csstools/css-tokenizer@4.0.0': + optional: true '@e18e/eslint-plugin@0.2.0(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))': 
dependencies: @@ -8777,7 +8804,8 @@ snapshots: '@eslint/core': 1.1.1 levn: 0.4.1 - '@exodus/bytes@1.15.0': {} + '@exodus/bytes@1.15.0': + optional: true '@floating-ui/core@1.7.5': dependencies: @@ -10356,7 +10384,7 @@ snapshots: '@tanstack/devtools-event-bus@0.4.1': dependencies: - ws: 8.19.0 + ws: 8.20.0 transitivePeerDependencies: - bufferutil - utf-8-validate @@ -10814,6 +10842,12 @@ snapshots: '@types/unist@3.0.3': {} + '@types/whatwg-mimetype@3.0.2': {} + + '@types/ws@8.18.1': + dependencies: + '@types/node': 25.5.0 + '@types/yauzl@2.10.3': dependencies: '@types/node': 25.5.0 @@ -11019,7 +11053,7 @@ snapshots: optionalDependencies: react-server-dom-webpack: 19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) - '@vitest/coverage-v8@4.1.0(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))': + '@vitest/coverage-v8@4.1.0(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))': dependencies: '@bcoe/v8-coverage': 1.0.2 '@vitest/utils': 4.1.0 @@ -11031,16 +11065,16 @@ snapshots: obug: 2.1.1 std-env: 4.0.0 tinyrainbow: 3.1.0 - vitest: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vitest: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' - '@vitest/eslint-plugin@1.6.12(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': + '@vitest/eslint-plugin@1.6.12(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: '@typescript-eslint/scope-manager': 8.57.1 '@typescript-eslint/utils': 8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) eslint: 10.1.0(jiti@1.21.7) optionalDependencies: typescript: 5.9.3 - vitest: 
'@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vitest: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - supports-color @@ -11105,7 +11139,7 @@ snapshots: '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.13': optional: true - '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': + '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': dependencies: '@standard-schema/spec': 1.1.0 '@types/chai': 5.2.3 @@ -11123,6 +11157,7 @@ snapshots: ws: 8.19.0 optionalDependencies: '@types/node': 25.5.0 + happy-dom: 20.8.9 jsdom: 29.0.1(canvas@3.2.2) transitivePeerDependencies: - '@arethetypeswrong/core' @@ -11419,11 +11454,10 @@ snapshots: baseline-browser-mapping@2.10.8: {} - bezier-easing@2.1.0: {} - bidi-js@1.0.3: dependencies: require-from-string: 2.0.2 + optional: true binary-extensions@2.3.0: {} @@ -11715,8 +11749,6 @@ snapshots: css-gradient-parser@0.0.16: {} - css-mediaquery@0.1.2: {} - css-select@5.2.2: dependencies: boolbase: 1.0.0 @@ -11745,6 +11777,7 @@ snapshots: dependencies: mdn-data: 2.27.1 source-map-js: 1.2.1 + optional: true css-what@6.2.2: {} @@ -11950,6 +11983,7 @@ snapshots: whatwg-url: 16.0.1 transitivePeerDependencies: - '@noble/hashes' + optional: true dayjs@1.11.20: {} @@ -12897,6 +12931,18 @@ snapshots: hachure-fill@0.5.2: {} + happy-dom@20.8.9: + dependencies: + '@types/node': 25.5.0 + '@types/whatwg-mimetype': 3.0.2 + '@types/ws': 8.18.1 + entities: 7.0.1 + whatwg-mimetype: 3.0.0 + ws: 8.20.0 + transitivePeerDependencies: + - bufferutil + - utf-8-validate + has-flag@4.0.0: {} hast-util-from-dom@5.0.1: @@ -13061,6 +13107,7 @@ snapshots: '@exodus/bytes': 1.15.0 transitivePeerDependencies: - '@noble/hashes' + optional: true html-entities@2.6.0: {} @@ -13199,7 +13246,8 @@ snapshots: is-plain-obj@4.1.0: {} - is-potential-custom-element-name@1.0.1: {} + is-potential-custom-element-name@1.0.1: + optional: true is-reference@3.0.3: dependencies: @@ -13261,17 +13309,12 @@ snapshots: jsdoc-type-pratt-parser@7.1.1: {} - jsdom-testing-mocks@1.16.0: - dependencies: - bezier-easing: 2.1.0 - css-mediaquery: 0.1.2 - jsdom@29.0.1(canvas@3.2.2): dependencies: - '@asamuzakjp/css-color': 5.0.1 - '@asamuzakjp/dom-selector': 7.0.3 + '@asamuzakjp/css-color': 5.1.1 + '@asamuzakjp/dom-selector': 7.0.4 '@bramus/specificity': 2.4.2 - '@csstools/css-syntax-patches-for-csstree': 1.1.1(css-tree@3.2.1) + 
'@csstools/css-syntax-patches-for-csstree': 1.1.2(css-tree@3.2.1) '@exodus/bytes': 1.15.0 css-tree: 3.2.1 data-urls: 7.0.0 @@ -13283,7 +13326,7 @@ snapshots: saxes: 6.0.0 symbol-tree: 3.2.4 tough-cookie: 6.0.1 - undici: 7.24.5 + undici: 7.24.6 w3c-xmlserializer: 5.0.0 webidl-conversions: 8.0.1 whatwg-mimetype: 5.0.0 @@ -13293,6 +13336,7 @@ snapshots: canvas: 3.2.2 transitivePeerDependencies: - '@noble/hashes' + optional: true jsesc@3.1.0: {} @@ -13753,7 +13797,8 @@ snapshots: mdn-data@2.23.0: {} - mdn-data@2.27.1: {} + mdn-data@2.27.1: + optional: true memoize-one@5.2.1: {} @@ -15121,6 +15166,7 @@ snapshots: saxes@6.0.0: dependencies: xmlchars: 2.2.0 + optional: true scheduler@0.27.0: {} @@ -15397,7 +15443,8 @@ snapshots: picocolors: 1.1.1 sax: 1.6.0 - symbol-tree@3.2.4: {} + symbol-tree@3.2.4: + optional: true synckit@0.11.12: dependencies: @@ -15443,6 +15490,8 @@ snapshots: tapable@2.3.0: {} + tapable@2.3.2: {} + tar-fs@2.1.4: dependencies: chownr: 1.1.4 @@ -15559,10 +15608,12 @@ snapshots: tough-cookie@6.0.1: dependencies: tldts: 7.0.27 + optional: true tr46@6.0.0: dependencies: punycode: 2.3.1 + optional: true trim-lines@3.0.1: {} @@ -15655,7 +15706,8 @@ snapshots: undici@7.24.0: {} - undici@7.24.5: {} + undici@7.24.6: + optional: true unicode-trie@2.0.0: dependencies: @@ -15879,11 +15931,11 @@ snapshots: - supports-color - typescript - vite-plus@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3): + vite-plus@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3): dependencies: '@oxc-project/types': 0.120.0 '@voidzero-dev/vite-plus-core': 0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) - '@voidzero-dev/vite-plus-test': 0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + '@voidzero-dev/vite-plus-test': 0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) cac: 7.0.0 cross-spawn: 7.0.6 oxfmt: 0.41.0 @@ -15950,11 +16002,11 @@ snapshots: optionalDependencies: vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' - 
vitest-canvas-mock@1.1.3(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): + vitest-canvas-mock@1.1.3(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): dependencies: cssfontparser: 1.2.1 moo-color: 1.0.3 - vitest: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vitest: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' void-elements@3.1.0: {} @@ -15990,6 +16042,7 @@ snapshots: w3c-xmlserializer@5.0.0: dependencies: xml-name-validator: 5.0.0 + optional: true walk-up-path@4.0.0: {} @@ -16002,7 +16055,8 @@ snapshots: web-vitals@5.1.0: {} - webidl-conversions@8.0.1: {} + webidl-conversions@8.0.1: + optional: true webpack-sources@3.3.4: {} @@ -16031,7 +16085,7 @@ snapshots: mime-types: 2.1.35 neo-async: 2.6.2 schema-utils: 4.3.3 - tapable: 2.3.0 + tapable: 2.3.2 terser-webpack-plugin: 5.4.0(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) watchpack: 2.5.1 webpack-sources: 3.3.4 @@ -16044,9 +16098,12 @@ snapshots: dependencies: iconv-lite: 0.6.3 + whatwg-mimetype@3.0.0: {} + whatwg-mimetype@4.0.0: {} - whatwg-mimetype@5.0.0: {} + whatwg-mimetype@5.0.0: + optional: true whatwg-url@16.0.1: dependencies: @@ -16055,6 +16112,7 @@ snapshots: webidl-conversions: 8.0.1 transitivePeerDependencies: - '@noble/hashes' + optional: true which@2.0.2: dependencies: @@ -16072,15 +16130,19 @@ snapshots: ws@8.19.0: {} + ws@8.20.0: {} + wsl-utils@0.1.0: dependencies: is-wsl: 3.1.1 xml-name-validator@4.0.0: {} - xml-name-validator@5.0.0: {} + xml-name-validator@5.0.0: + optional: true - xmlchars@2.2.0: {} + xmlchars@2.2.0: + optional: true xtend@4.0.2: {} diff --git a/web/vite.config.ts b/web/vite.config.ts index 617cae9ab5..28746f81ca 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -75,7 +75,8 @@ export default defineConfig(({ mode }) => { // Vitest config test: { - environment: 'jsdom', + pool: 'threads', + environment: 'happy-dom', globals: true, setupFiles: ['./vitest.setup.ts'], coverage: { diff --git a/web/vitest.setup.ts b/web/vitest.setup.ts index e63ea2b54e..ac26ac5d25 100644 --- a/web/vitest.setup.ts +++ b/web/vitest.setup.ts @@ -1,14 +1,8 @@ import { act, cleanup } from '@testing-library/react' -import { mockAnimationsApi, mockResizeObserver } from 'jsdom-testing-mocks' import * as React from 'react' import '@testing-library/jest-dom/vitest' import 'vitest-canvas-mock' 
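+// ResizeObserver and the Web Animations API are assumed to be covered natively
+// by happy-dom, so no jsdom-testing-mocks shims are registered in this setup file.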
-mockResizeObserver()
-
-// Mock Web Animations API for Headless UI
-mockAnimationsApi()
-
 // Suppress act() warnings from @headlessui/react internal Transition component
 // These warnings are caused by Headless UI's internal async state updates, not our code
 const originalConsoleError = console.error
@@ -77,24 +71,10 @@ if (typeof globalThis.IntersectionObserver === 'undefined') {
   }
 }
 
-// Mock Element.scrollIntoView for tests (not available in happy-dom/jsdom)
-if (typeof Element !== 'undefined' && !Element.prototype.scrollIntoView)
-  Element.prototype.scrollIntoView = function () { /* noop */ }
-
-// Mock DOMRect.fromRect for tests (not available in jsdom)
-if (typeof DOMRect !== 'undefined' && typeof (DOMRect as typeof DOMRect & { fromRect?: unknown }).fromRect !== 'function') {
-  (DOMRect as typeof DOMRect & { fromRect: (rect?: DOMRectInit) => DOMRect }).fromRect = (rect = {}) => new DOMRect(
-    rect.x ?? 0,
-    rect.y ?? 0,
-    rect.width ?? 0,
-    rect.height ?? 0,
-  )
-}
-
 afterEach(async () => {
   // Wrap cleanup in act() to flush pending React scheduler work
   // This prevents "window is not defined" errors from React 19's scheduler
-  // which uses setImmediate/MessageChannel that can fire after jsdom cleanup
+  // which uses setImmediate/MessageChannel that can fire after DOM cleanup
   await act(async () => {
     cleanup()
   })
@@ -131,19 +111,100 @@ vi.mock('@floating-ui/react', async () => {
   }
 })
 
-// mock window.matchMedia
-Object.defineProperty(window, 'matchMedia', {
-  writable: true,
-  value: vi.fn().mockImplementation(query => ({
-    matches: false,
-    media: query,
-    onchange: null,
-    addListener: vi.fn(), // deprecated
-    removeListener: vi.fn(), // deprecated
-    addEventListener: vi.fn(),
-    removeEventListener: vi.fn(),
-    dispatchEvent: vi.fn(),
-  })),
+vi.mock('@monaco-editor/react', () => {
+  // Minimal stand-in for a monaco editor instance: only the surface our components touch.
+  const createEditorMock = () => {
+    const focusListeners: Array<() => void> = []
+    const blurListeners: Array<() => void> = []
+
+    return {
+      getContentHeight: vi.fn(() => 56),
+      onDidFocusEditorText: vi.fn((listener: () => void) => {
+        focusListeners.push(listener)
+        return { dispose: vi.fn() }
+      }),
+      onDidBlurEditorText: vi.fn((listener: () => void) => {
+        blurListeners.push(listener)
+        return { dispose: vi.fn() }
+      }),
+      layout: vi.fn(),
+      getAction: vi.fn(() => ({ run: vi.fn() })),
+      getModel: vi.fn(() => ({
+        getLineContent: vi.fn(() => ''),
+      })),
+      getPosition: vi.fn(() => ({ lineNumber: 1, column: 1 })),
+      deltaDecorations: vi.fn(() => []),
+      focus: vi.fn(() => {
+        focusListeners.forEach(listener => listener())
+      }),
+      setPosition: vi.fn(),
+      revealLine: vi.fn(),
+      trigger: vi.fn(),
+      __blur: () => {
+        blurListeners.forEach(listener => listener())
+      },
+    }
+  }
+
+  // Barebones monaco namespace exposing just the APIs the app calls at mount time.
+  const monacoMock = {
+    editor: {
+      setTheme: vi.fn(),
+      defineTheme: vi.fn(),
+    },
+    Range: class {
+      startLineNumber: number
+      startColumn: number
+      endLineNumber: number
+      endColumn: number
+      constructor(startLineNumber: number, startColumn: number, endLineNumber: number, endColumn: number) {
+        this.startLineNumber = startLineNumber
+        this.startColumn = startColumn
+        this.endLineNumber = endLineNumber
+        this.endColumn = endColumn
+      }
+    },
+  }
+
+  // Render a plain textarea in place of the editor so tests can read and type text.
+  const MonacoEditor = ({
+    value = '',
+    onChange,
+    onMount,
+    options,
+  }: {
+    value?: string
+    onChange?: (value: string | undefined) => void
+    onMount?: (editor: ReturnType<typeof createEditorMock>, monaco: typeof monacoMock) => void
+    options?: { readOnly?: boolean }
+  }) => {
+    const editorRef = React.useRef<ReturnType<typeof createEditorMock> | null>(null)
+    if (!editorRef.current)
+      editorRef.current = createEditorMock()
+
+    React.useEffect(() => {
+      onMount?.(editorRef.current!, monacoMock)
+    }, [onMount])
+
+    return React.createElement('textarea', {
+      'data-testid': 'monaco-editor',
+      'readOnly': options?.readOnly,
+      value,
+      'onChange': (event: React.ChangeEvent<HTMLTextAreaElement>) => onChange?.(event.target.value),
+      'onFocus': () => editorRef.current?.focus(),
+      'onBlur': () => editorRef.current?.__blur(),
+    })
+  }
+
+  return {
+    __esModule: true,
+    default: MonacoEditor,
+    Editor: MonacoEditor,
+    loader: {
+      config: vi.fn(),
+      init: vi.fn().mockResolvedValue(monacoMock),
+    },
+  }
 })
 
 // Mock localStorage for testing