mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 21:00:30 -04:00
Merge branch 'main' into jzh
This commit is contained in:
9
.github/labeler.yml
vendored
9
.github/labeler.yml
vendored
@@ -1,3 +1,10 @@
|
||||
web:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'web/**'
|
||||
- any-glob-to-any-file:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
|
||||
82
.github/scripts/generate-i18n-changes.mjs
vendored
Normal file
82
.github/scripts/generate-i18n-changes.mjs
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
import { execFileSync } from 'node:child_process'
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json'
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const currentJson = readCurrentJson(fileStem)
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = currentJson || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: currentJson === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
outputPath,
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@@ -39,9 +39,11 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
|
||||
2
.github/workflows/docker-build.yml
vendored
2
.github/workflows/docker-build.yml
vendored
@@ -8,9 +8,11 @@ on:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- packages/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
|
||||
4
.github/workflows/main-ci.yml
vendored
4
.github/workflows/main-ci.yml
vendored
@@ -65,9 +65,11 @@ jobs:
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
@@ -77,9 +79,11 @@ jobs:
|
||||
- 'api/uv.lock'
|
||||
- 'e2e/**'
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@@ -77,9 +77,11 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
1
.github/workflows/tool-test-sdks.yaml
vendored
1
.github/workflows/tool-test-sdks.yaml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
101
.github/workflows/translate-i18n-claude.yml
vendored
101
.github/workflows/translate-i18n-claude.yml
vendored
@@ -68,89 +68,7 @@ jobs:
|
||||
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
|
||||
|
||||
generate_changes_json() {
|
||||
node <<'NODE'
|
||||
const { execFileSync } = require('node:child_process')
|
||||
const fs = require('node:fs')
|
||||
const path = require('node:path')
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const currentJson = readCurrentJson(fileStem)
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = currentJson || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: currentJson === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
'/tmp/i18n-changes.json',
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
NODE
|
||||
node .github/scripts/generate-i18n-changes.mjs
|
||||
}
|
||||
|
||||
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
|
||||
@@ -270,7 +188,7 @@ jobs:
|
||||
Tool rules:
|
||||
- Use Read for repository files.
|
||||
- Use Edit for JSON updates.
|
||||
- Use Bash only for `pnpm`.
|
||||
- Use Bash only for `vp`.
|
||||
- Do not use Bash for `git`, `gh`, or branch management.
|
||||
|
||||
Required execution plan:
|
||||
@@ -292,7 +210,7 @@ jobs:
|
||||
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
|
||||
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
|
||||
4. Run a scoped pre-check before editing:
|
||||
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Use this command as the source of truth for missing and extra keys inside the current scope.
|
||||
5. Apply translations.
|
||||
- For every target language and scoped file:
|
||||
@@ -300,19 +218,19 @@ jobs:
|
||||
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
|
||||
- ADD missing keys.
|
||||
- UPDATE stale translations when the English value changed.
|
||||
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
|
||||
- Match the existing terminology and register used by each locale.
|
||||
- Prefer one Edit per file when stable, but prioritize correctness over batching.
|
||||
6. Verify only the edited files.
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
|
||||
- Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- If verification fails, fix the remaining problems before continuing.
|
||||
7. Stop after the scoped locale files are updated and verification passes.
|
||||
- Do not create branches, commits, or pull requests.
|
||||
claude_args: |
|
||||
--max-turns 120
|
||||
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
|
||||
--allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
|
||||
|
||||
- name: Prepare branch metadata
|
||||
id: pr_meta
|
||||
@@ -354,6 +272,7 @@ jobs:
|
||||
- name: Create or update translation PR
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
|
||||
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
|
||||
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
|
||||
@@ -402,8 +321,8 @@ jobs:
|
||||
'',
|
||||
'## Verification',
|
||||
'',
|
||||
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
|
||||
`- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||
`- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
|
||||
'',
|
||||
'## Notes',
|
||||
'',
|
||||
|
||||
83
.github/workflows/trigger-i18n-sync.yml
vendored
83
.github/workflows/trigger-i18n-sync.yml
vendored
@@ -42,88 +42,7 @@ jobs:
|
||||
fi
|
||||
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
node <<'NODE'
|
||||
const { execFileSync } = require('node:child_process')
|
||||
const fs = require('node:fs')
|
||||
const path = require('node:path')
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = readCurrentJson(fileStem) || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: readCurrentJson(fileStem) === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
'/tmp/i18n-changes.json',
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
NODE
|
||||
node .github/scripts/generate-i18n-changes.mjs
|
||||
|
||||
if [ -n "$CHANGED_FILES" ]; then
|
||||
echo "has_changes=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
@@ -81,8 +81,8 @@ if $web_modified; then
|
||||
|
||||
if $web_ts_modified; then
|
||||
echo "Running TypeScript type-check:tsgo"
|
||||
if ! pnpm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
|
||||
if ! npm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
@@ -90,8 +90,8 @@ if $web_modified; then
|
||||
fi
|
||||
|
||||
echo "Running knip"
|
||||
if ! pnpm run knip; then
|
||||
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
|
||||
if ! npm run knip; then
|
||||
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
79
api/controllers/common/controller_schemas.py
Normal file
79
api/controllers/common/controller_schemas.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
|
||||
# --- Conversation schemas ---
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
# --- Message schemas ---
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
# --- Saved message schemas ---
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
# --- Workflow schemas ---
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
# --- Audio schemas ---
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
@@ -7,7 +7,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest
|
||||
@@ -26,9 +26,11 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
@@ -41,10 +43,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
NotionIcon,
|
||||
NotionInfo,
|
||||
NotionPage,
|
||||
PreProcessingRule,
|
||||
RerankingModel,
|
||||
Rule,
|
||||
Segmentation,
|
||||
WebsiteInfo,
|
||||
WeightKeywordSetting,
|
||||
WeightModel,
|
||||
@@ -155,16 +154,6 @@ class AppTracePayload(BaseModel):
|
||||
type JSONValue = Any
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@@ -71,7 +71,7 @@ class AppImportApi(Resource):
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
@@ -92,11 +92,13 @@ class AppImportApi(Resource):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@@ -59,10 +60,8 @@ class ChatMessagesQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow_run import workflow_run_node_execution_model
|
||||
@@ -142,10 +143,6 @@ class PublishWorkflowPayload(BaseModel):
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class ConvertToWorkflowPayload(BaseModel):
|
||||
name: str | None = None
|
||||
icon_type: str | None = None
|
||||
@@ -153,18 +150,6 @@ class ConvertToWorkflowPayload(BaseModel):
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -20,35 +20,18 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from libs.password import hash_password
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.entities.auth_entities import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
class ForgotPasswordEmailResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
data: str | None = Field(default=None, description="Reset token")
|
||||
|
||||
@@ -42,6 +42,7 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
@@ -49,9 +50,7 @@ from services.feature_service import FeatureService
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
password: str = Field(..., description="Password")
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||
|
||||
|
||||
@@ -83,11 +83,13 @@ class RagPipelineImportApi(Resource):
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@@ -94,22 +95,6 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
|
||||
original_document_id: str | None = None
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class NodeIdQuery(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ import logging
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -32,14 +32,6 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(console_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
@@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel):
|
||||
pinned: bool | None = None
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
@@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
@@ -44,17 +44,6 @@ from .. import console_ns
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class MoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"]
|
||||
|
||||
|
||||
@@ -1,28 +1,18 @@
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@@ -34,12 +33,6 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
register_schema_model(console_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantStatus
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@@ -240,8 +240,10 @@ class CustomConfigWorkspaceApi(Resource):
|
||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||
|
||||
custom_config_dict = {
|
||||
"remove_webapp_brand": args.remove_webapp_brand,
|
||||
custom_config_dict: TenantCustomConfigDict = {
|
||||
"remove_webapp_brand": args.remove_webapp_brand
|
||||
if args.remove_webapp_brand is not None
|
||||
else tenant.custom_config_dict.get("remove_webapp_brand", False),
|
||||
"replace_webapp_logo": args.replace_webapp_logo
|
||||
if args.replace_webapp_logo is not None
|
||||
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
||||
|
||||
@@ -9,7 +9,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
@@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
@@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
|
||||
@@ -2,11 +2,12 @@ from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
@@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
|
||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
@@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
@@ -27,17 +26,6 @@ from services.message_service import MessageService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class FeedbackListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
@@ -46,9 +47,7 @@ from services.workflow_app_service import WorkflowAppService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
class WorkflowRunPayload(WorkflowRunPayloadBase):
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from controllers.service_api.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
)
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
@@ -40,11 +41,8 @@ from models.enums import SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
KnowledgeConfig,
|
||||
PreProcessingRule,
|
||||
ProcessRule,
|
||||
RetrievalModel,
|
||||
Rule,
|
||||
Segmentation,
|
||||
)
|
||||
from services.file_service import FileService
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
@@ -3,10 +3,11 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import field_validator
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
@@ -34,12 +35,7 @@ from services.errors.audio import (
|
||||
from ..common.schema import register_schema_models
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
class TextToAudioPayload(TextToAudioPayloadBase):
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotChatAppError
|
||||
@@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -19,33 +18,15 @@ from controllers.console.error import EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.password import hash_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(min_length=1)
|
||||
new_password: str
|
||||
password_confirm: str
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
from services.entities.auth_entities import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
|
||||
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
||||
|
||||
|
||||
@@ -29,13 +29,11 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
|
||||
@@ -6,6 +6,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
@@ -53,11 +54,6 @@ class MessageListQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class MessageMoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"] = Field(
|
||||
description="Response mode",
|
||||
|
||||
@@ -138,12 +138,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
||||
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
if auth_type == WebAppAuthType.PUBLIC:
|
||||
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
|
||||
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
|
||||
raise WebAppAuthRequiredError("Please login as external user.")
|
||||
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
|
||||
raise WebAppAuthRequiredError("Please login as internal user.")
|
||||
match auth_type:
|
||||
case WebAppAuthType.PUBLIC:
|
||||
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
|
||||
case WebAppAuthType.EXTERNAL:
|
||||
if user_auth_type != "external":
|
||||
raise WebAppAuthRequiredError("Please login as external user.")
|
||||
case WebAppAuthType.INTERNAL:
|
||||
if user_auth_type != "internal":
|
||||
raise WebAppAuthRequiredError("Please login as internal user.")
|
||||
|
||||
end_user = None
|
||||
if end_user_id:
|
||||
|
||||
@@ -1,27 +1,17 @@
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
@@ -30,12 +29,6 @@ from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(web_ns, WorkflowRunPayload)
|
||||
|
||||
@@ -79,21 +79,18 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
if not agent_scratchpad:
|
||||
assistant_messages = []
|
||||
else:
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
content = ""
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
content += f"Action: {unit.action_str}\n\n"
|
||||
if unit.observation:
|
||||
assistant_message.content += f"Observation: {unit.observation}\n\n"
|
||||
content += f"Observation: {unit.observation}\n\n"
|
||||
|
||||
assistant_messages = [assistant_message]
|
||||
assistant_messages = [AssistantPromptMessage(content=content)]
|
||||
|
||||
# query messages
|
||||
query_messages = self._organize_user_query(self._query, [])
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -9,6 +8,7 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel):
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
SupportedComparisonOperator = Literal[
|
||||
# for string or array
|
||||
"contains",
|
||||
"not contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
# for time
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
@@ -143,25 +118,6 @@ class ModelConfig(BaseModel):
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition detail
|
||||
"""
|
||||
|
||||
name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
"""
|
||||
Dataset Retrieve Config Entity.
|
||||
|
||||
@@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from graphon.variable_loader import VariableLoader
|
||||
from graphon.variables.variables import Variable
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
:return: List of conversation variables ready for use
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
existing_variables = self._load_existing_conversation_variables(session)
|
||||
|
||||
if not existing_variables:
|
||||
@@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# Convert to Variable objects for use in the workflow
|
||||
conversation_variables = [var.to_variable() for var in existing_variables]
|
||||
|
||||
session.commit()
|
||||
return cast(list[Variable], conversation_variables)
|
||||
|
||||
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
|
||||
|
||||
@@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
@contextmanager
|
||||
def _database_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
yield session
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
"""Fluent validation for workflow state."""
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Union
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
@contextmanager
|
||||
def _database_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
yield session
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
"""Fluent validation for workflow state."""
|
||||
|
||||
@@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.system_variables import (
|
||||
build_bootstrap_variables,
|
||||
|
||||
@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
|
||||
@@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import ModelStatus
|
||||
@@ -73,7 +73,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
@@ -90,4 +90,3 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
session.commit()
|
||||
yield self.error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
@@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
session.commit()
|
||||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy import select, update
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -1,22 +1,3 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
|
||||
en_US: str
|
||||
zh_Hans: str | None = Field(default=None)
|
||||
pt_BR: str | None = Field(default=None)
|
||||
ja_JP: str | None = Field(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
__all__ = ["I18nObject", "I18nObjectDict"]
|
||||
|
||||
@@ -9,7 +9,7 @@ from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
from core.plugin.entities import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
|
||||
@@ -1 +1,8 @@
|
||||
from core.entities.plugin_credential_type import PluginCredentialType
|
||||
|
||||
DEFAULT_PLUGIN_ID = "langgenius"
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_PLUGIN_ID",
|
||||
"PluginCredentialType",
|
||||
]
|
||||
|
||||
9
api/core/entities/plugin_credential_type.py
Normal file
9
api/core/entities/plugin_credential_type.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import enum
|
||||
|
||||
|
||||
class PluginCredentialType(enum.Enum):
|
||||
MODEL = 0 # must be 0 for API contract compatibility
|
||||
TOOL = 1 # must be 1 for API contract compatibility
|
||||
|
||||
def to_number(self):
|
||||
return self.value
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import PluginCredentialType
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
@@ -46,7 +47,6 @@ from models.provider import (
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Credential utility functions for checking credential existence and policy compliance.
|
||||
"""
|
||||
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from core.entities import PluginCredentialType
|
||||
|
||||
|
||||
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
|
||||
|
||||
@@ -17,6 +17,7 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
@@ -25,7 +26,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import ProviderType
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
5
api/core/plugin/entities/__init__.py
Normal file
5
api/core/plugin/entities/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
|
||||
__all__ = [
|
||||
"OAuthSchema",
|
||||
]
|
||||
@@ -1,5 +1,3 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
@@ -10,12 +8,12 @@ class OAuthSchema(BaseModel):
|
||||
OAuth schema
|
||||
"""
|
||||
|
||||
client_schema: Sequence[ProviderConfig] = Field(
|
||||
client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="client schema like client_id, client_secret, etc.",
|
||||
)
|
||||
|
||||
credentials_schema: Sequence[ProviderConfig] = Field(
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="credentials schema like access_token, refresh_token, etc.",
|
||||
)
|
||||
|
||||
@@ -209,7 +209,10 @@ class PluginInstaller(BasePluginClient):
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/decode/from_identifier",
|
||||
PluginDecodeResponse,
|
||||
params={"plugin_unique_identifier": plugin_unique_identifier},
|
||||
params={
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
"PluginUniqueIdentifier": plugin_unique_identifier, # compat with daemon <= 0.5.4
|
||||
},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_by_ids(
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@@ -15,6 +14,7 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -58,6 +58,8 @@ from services.feature_service import FeatureService
|
||||
if TYPE_CHECKING:
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
|
||||
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
"""
|
||||
@@ -875,8 +877,8 @@ class ProviderManager:
|
||||
return {"openai_api_key": encrypted_config}
|
||||
|
||||
try:
|
||||
credentials = cast(dict, json.loads(encrypted_config))
|
||||
except JSONDecodeError:
|
||||
credentials = _credentials_adapter.validate_json(encrypted_config)
|
||||
except (ValueError, JSONDecodeError):
|
||||
return {}
|
||||
|
||||
# Decrypt secret variables
|
||||
@@ -1015,7 +1017,7 @@ class ProviderManager:
|
||||
if not cached_provider_credentials:
|
||||
provider_credentials: dict[str, Any] = {}
|
||||
if provider_records and provider_records[0].encrypted_config:
|
||||
provider_credentials = json.loads(provider_records[0].encrypted_config)
|
||||
provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
@@ -1162,8 +1164,10 @@ class ProviderManager:
|
||||
|
||||
if not cached_provider_model_credentials:
|
||||
try:
|
||||
provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_model_credentials = _credentials_adapter.validate_json(
|
||||
load_balancing_model_config.encrypted_config
|
||||
)
|
||||
except (ValueError, JSONDecodeError):
|
||||
continue
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
@@ -1176,7 +1180,7 @@ class ProviderManager:
|
||||
if variable in provider_model_credentials:
|
||||
try:
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
provider_model_credentials.get(variable) or "",
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@@ -182,7 +182,9 @@ class RetrievalService:
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
|
||||
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||
if metadata_filtering_conditions
|
||||
else None
|
||||
)
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id,
|
||||
|
||||
@@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self.analyticdb_vector._create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
self.analyticdb_vector.add_texts(documents, embeddings)
|
||||
return []
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self.analyticdb_vector.text_exists(id)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@@ -13,6 +13,13 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class AnalyticdbClientParamsDict(TypedDict):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
region_id: str
|
||||
read_timeout: int
|
||||
|
||||
|
||||
class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
@@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
||||
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
|
||||
result: AnalyticdbClientParamsDict = {
|
||||
"access_key_id": self.access_key_id,
|
||||
"access_key_secret": self.access_key_secret,
|
||||
"region_id": self.region_id,
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class AnalyticdbVectorOpenAPI:
|
||||
@@ -115,7 +123,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
else:
|
||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
def create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
@@ -74,7 +75,7 @@ class AnalyticdbVectorBySql:
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
def _get_cursor(self) -> Iterator[Any]:
|
||||
assert self.pool is not None, "Connection pool is not initialized"
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
@@ -130,7 +131,7 @@ class AnalyticdbVectorBySql:
|
||||
)
|
||||
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
def create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
|
||||
@@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field as VDBField
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
@@ -85,8 +85,12 @@ class BaiduVector(BaseVector):
|
||||
def get_type(self) -> str:
|
||||
return VectorType.BAIDU
|
||||
|
||||
def to_index_struct(self):
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
def to_index_struct(self) -> VectorIndexStructDict:
|
||||
result: VectorIndexStructDict = {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name},
|
||||
}
|
||||
return result
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self._create_table(len(embeddings[0]))
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import chromadb
|
||||
from chromadb import QueryResult, Settings
|
||||
from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
@@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class ChromaParamsDict(TypedDict):
|
||||
host: str
|
||||
port: int
|
||||
ssl: bool
|
||||
tenant: str
|
||||
database: str
|
||||
settings: Settings
|
||||
|
||||
|
||||
class ChromaConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
@@ -23,14 +32,13 @@ class ChromaConfig(BaseModel):
|
||||
auth_provider: str | None = None
|
||||
auth_credentials: str | None = None
|
||||
|
||||
def to_chroma_params(self):
|
||||
def to_chroma_params(self) -> ChromaParamsDict:
|
||||
settings = Settings(
|
||||
# auth
|
||||
chroma_client_auth_provider=self.auth_provider,
|
||||
chroma_client_auth_credentials=self.auth_credentials,
|
||||
)
|
||||
|
||||
return {
|
||||
result: ChromaParamsDict = {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"ssl": False,
|
||||
@@ -38,6 +46,7 @@ class ChromaConfig(BaseModel):
|
||||
"database": self.database,
|
||||
"settings": settings,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class ChromaVector(BaseVector):
|
||||
@@ -97,14 +106,15 @@ class ChromaVector(BaseVector):
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
results: QueryResult
|
||||
if document_ids_filter:
|
||||
results: QueryResult = collection.query(
|
||||
results = collection.query(
|
||||
query_embeddings=query_vector,
|
||||
n_results=kwargs.get("top_k", 4),
|
||||
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
|
||||
)
|
||||
else:
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
|
||||
results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
# Check if results contain data
|
||||
@@ -145,7 +155,10 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
|
||||
index_struct_dict: VectorIndexStructDict = {
|
||||
"type": VectorType.CHROMA,
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
return ChromaVector(
|
||||
@@ -153,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
||||
config=ChromaConfig(
|
||||
host=dify_config.CHROMA_HOST or "",
|
||||
port=dify_config.CHROMA_PORT,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage]
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage]
|
||||
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
|
||||
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
|
||||
),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -20,6 +20,15 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusParamsDict(TypedDict):
|
||||
uri: str
|
||||
token: str | None
|
||||
user: str | None
|
||||
password: str | None
|
||||
db_name: str
|
||||
analyzer_params: str | None
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
"""
|
||||
Configuration class for Milvus connection.
|
||||
@@ -50,11 +59,11 @@ class MilvusConfig(BaseModel):
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
def to_milvus_params(self) -> MilvusParamsDict:
|
||||
"""
|
||||
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||
"""
|
||||
return {
|
||||
result: MilvusParamsDict = {
|
||||
"uri": self.uri,
|
||||
"token": self.token,
|
||||
"user": self.user,
|
||||
@@ -62,6 +71,7 @@ class MilvusConfig(BaseModel):
|
||||
"db_name": self.database,
|
||||
"analyzer_params": self.analyzer_params,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class MilvusVector(BaseVector):
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import qdrant_client
|
||||
from flask import current_app
|
||||
@@ -22,7 +22,7 @@ from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
@@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions import common_types
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
@@ -94,8 +93,12 @@ class QdrantVector(BaseVector):
|
||||
def get_type(self) -> str:
|
||||
return VectorType.QDRANT
|
||||
|
||||
def to_index_struct(self):
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
def to_index_struct(self) -> VectorIndexStructDict:
|
||||
result: VectorIndexStructDict = {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name},
|
||||
}
|
||||
return result
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
@@ -176,7 +179,7 @@ class QdrantVector(BaseVector):
|
||||
for batch_ids, points in self._generate_rest_batches(
|
||||
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
|
||||
):
|
||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
||||
self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points))
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
@@ -468,7 +471,7 @@ class QdrantVector(BaseVector):
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client._load()
|
||||
self._client._load() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
|
||||
@@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Base = declarative_base() # type: Any
|
||||
Base: Any = declarative_base()
|
||||
|
||||
|
||||
class RelytConfig(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tcvdb_text.encoder import BM25Encoder # type: ignore
|
||||
@@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
@@ -23,6 +23,13 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TencentParamsDict(TypedDict):
|
||||
url: str
|
||||
username: str | None
|
||||
key: str | None
|
||||
timeout: float
|
||||
|
||||
|
||||
class TencentConfig(BaseModel):
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
@@ -36,8 +43,14 @@ class TencentConfig(BaseModel):
|
||||
max_upsert_batch_size: int = 128
|
||||
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||
|
||||
def to_tencent_params(self):
|
||||
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
|
||||
def to_tencent_params(self) -> TencentParamsDict:
|
||||
result: TencentParamsDict = {
|
||||
"url": self.url,
|
||||
"username": self.username,
|
||||
"key": self.api_key,
|
||||
"timeout": self.timeout,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
bm25 = BM25Encoder.default("zh")
|
||||
@@ -83,8 +96,12 @@ class TencentVector(BaseVector):
|
||||
def get_type(self) -> str:
|
||||
return VectorType.TENCENT
|
||||
|
||||
def to_index_struct(self):
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
def to_index_struct(self) -> VectorIndexStructDict:
|
||||
result: VectorIndexStructDict = {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name},
|
||||
}
|
||||
return result
|
||||
|
||||
def _has_collection(self) -> bool:
|
||||
return bool(
|
||||
|
||||
@@ -25,7 +25,7 @@ from sqlalchemy import select
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
@@ -91,8 +91,12 @@ class TidbOnQdrantVector(BaseVector):
|
||||
def get_type(self) -> str:
|
||||
return VectorType.TIDB_ON_QDRANT
|
||||
|
||||
def to_index_struct(self):
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
def to_index_struct(self) -> VectorIndexStructDict:
|
||||
result: VectorIndexStructDict = {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name},
|
||||
}
|
||||
return result
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class VectorStoreDict(TypedDict):
|
||||
class_prefix: str
|
||||
|
||||
|
||||
class VectorIndexStructDict(TypedDict):
|
||||
type: str
|
||||
vector_store: VectorStoreDict
|
||||
|
||||
|
||||
class BaseVector(ABC):
|
||||
def __init__(self, collection_name: str):
|
||||
self._collection_name = collection_name
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
@@ -30,8 +30,11 @@ class AbstractVectorFactory(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
|
||||
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
|
||||
index_struct_dict: VectorIndexStructDict = {
|
||||
"type": vector_type,
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
return index_struct_dict
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from weaviate.exceptions import UnexpectedStatusCodeError
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
@@ -184,9 +184,13 @@ class WeaviateVector(BaseVector):
|
||||
dataset_id = dataset.id
|
||||
return Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
def to_index_struct(self) -> VectorIndexStructDict:
|
||||
"""Returns the index structure dictionary for persistence."""
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
result: VectorIndexStructDict = {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name},
|
||||
}
|
||||
return result
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
|
||||
28
api/core/rag/entities/__init__.py
Normal file
28
api/core/rag/entities/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
|
||||
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
|
||||
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
|
||||
|
||||
__all__ = [
|
||||
"Condition",
|
||||
"DatasourceCompletedEvent",
|
||||
"DatasourceErrorEvent",
|
||||
"DatasourceProcessingEvent",
|
||||
"DocumentContext",
|
||||
"EconomySetting",
|
||||
"EmbeddingSetting",
|
||||
"IndexMethod",
|
||||
"KeywordSetting",
|
||||
"MetadataFilteringCondition",
|
||||
"ParentMode",
|
||||
"PreProcessingRule",
|
||||
"RetrievalSourceMetadata",
|
||||
"Rule",
|
||||
"Segmentation",
|
||||
"SupportedComparisonOperator",
|
||||
"VectorSetting",
|
||||
"WeightedScoreConfig",
|
||||
]
|
||||
30
api/core/rag/entities/index_entities.py
Normal file
30
api/core/rag/entities/index_entities.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EmbeddingSetting(BaseModel):
|
||||
"""
|
||||
Embedding Setting.
|
||||
"""
|
||||
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class EconomySetting(BaseModel):
|
||||
"""
|
||||
Economy Setting.
|
||||
"""
|
||||
|
||||
keyword_number: int
|
||||
|
||||
|
||||
class IndexMethod(BaseModel):
|
||||
"""
|
||||
Knowledge Index Setting.
|
||||
"""
|
||||
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
embedding_setting: EmbeddingSetting
|
||||
economy_setting: EconomySetting
|
||||
@@ -38,9 +38,9 @@ class Condition(BaseModel):
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataCondition(BaseModel):
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Condition.
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
|
||||
27
api/core/rag/entities/processing_entities.py
Normal file
27
api/core/rag/entities/processing_entities.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ParentMode(StrEnum):
|
||||
FULL_DOC = "full-doc"
|
||||
PARAGRAPH = "paragraph"
|
||||
|
||||
|
||||
class PreProcessingRule(BaseModel):
|
||||
id: str
|
||||
enabled: bool
|
||||
|
||||
|
||||
class Segmentation(BaseModel):
|
||||
separator: str = "\n"
|
||||
max_tokens: int
|
||||
chunk_overlap: int = 0
|
||||
|
||||
|
||||
class Rule(BaseModel):
|
||||
pre_processing_rules: list[PreProcessingRule] | None = None
|
||||
segmentation: Segmentation | None = None
|
||||
parent_mode: Literal["full-doc", "paragraph"] | None = None
|
||||
subchunk_segmentation: Segmentation | None = None
|
||||
28
api/core/rag/entities/retrieval_settings.py
Normal file
28
api/core/rag/entities/retrieval_settings.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
@@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor):
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
||||
from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage]
|
||||
FileType,
|
||||
detect_filetype,
|
||||
)
|
||||
|
||||
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
|
||||
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
||||
except ImportError:
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
||||
from core.workflow.nodes.knowledge_index.protocols import IndexingResultDict, Preview, PreviewItem, QaPreview
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .index_processor_factory import IndexProcessorFactory
|
||||
@@ -61,7 +61,7 @@ class IndexProcessor:
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
):
|
||||
) -> IndexingResultDict:
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if not document:
|
||||
@@ -129,7 +129,7 @@ class IndexProcessor:
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
result: IndexingResultDict = {
|
||||
"dataset_id": dataset_id,
|
||||
"dataset_name": dataset_name_value,
|
||||
"batch": batch,
|
||||
@@ -138,6 +138,7 @@ class IndexProcessor:
|
||||
"created_at": created_at_value.timestamp(),
|
||||
"display_status": "completed",
|
||||
}
|
||||
return result
|
||||
|
||||
def get_preview_output(
|
||||
self,
|
||||
|
||||
@@ -32,6 +32,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.entities import Rule
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@@ -49,7 +50,6 @@ from models.account import Account
|
||||
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
@@ -17,6 +17,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.entities import ParentMode, Rule
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@@ -30,7 +31,6 @@ from models import Account
|
||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.entities import Rule
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
@@ -30,7 +31,6 @@ from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,16 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
vector_weight: float
|
||||
|
||||
embedding_provider_name: str
|
||||
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
keyword_weight: float
|
||||
from core.rag.entities import KeywordSetting, VectorSetting
|
||||
|
||||
|
||||
class Weights(BaseModel):
|
||||
|
||||
@@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -39,9 +39,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@@ -604,7 +602,7 @@ class DatasetRetrieval:
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
@@ -743,7 +741,7 @@ class DatasetRetrieval:
|
||||
reranking_enable: bool = True,
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
@@ -886,7 +884,7 @@ class DatasetRetrieval:
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
return
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Collect all document_ids and batch fetch DatasetDocuments
|
||||
document_ids = {
|
||||
doc.metadata["document_id"]
|
||||
@@ -977,7 +975,6 @@ class DatasetRetrieval:
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
|
||||
@@ -1063,7 +1060,7 @@ class DatasetRetrieval:
|
||||
top_k: int,
|
||||
all_documents: list[Document],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
@@ -1339,7 +1336,7 @@ class DatasetRetrieval:
|
||||
metadata_model_config: ModelConfig,
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None,
|
||||
inputs: dict,
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
|
||||
document_query = select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id.in_(dataset_ids),
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
@@ -1371,7 +1368,7 @@ class DatasetRetrieval:
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
metadata_condition = MetadataFilteringCondition(
|
||||
logical_operator=metadata_filtering_conditions.logical_operator
|
||||
if metadata_filtering_conditions
|
||||
else "or", # type: ignore
|
||||
@@ -1400,7 +1397,7 @@ class DatasetRetrieval:
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
metadata_condition = MetadataFilteringCondition(
|
||||
logical_operator=metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
@@ -1723,7 +1720,7 @@ class DatasetRetrieval:
|
||||
self,
|
||||
flask_app: Flask,
|
||||
available_datasets: list[Dataset],
|
||||
metadata_condition: MetadataCondition | None,
|
||||
metadata_condition: MetadataFilteringCondition | None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None,
|
||||
all_documents: list[Document],
|
||||
tenant_id: str,
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class I18nObjectDict(TypedDict):
|
||||
zh_Hans: str | None
|
||||
en_US: str
|
||||
pt_BR: str | None
|
||||
ja_JP: str | None
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
@@ -18,5 +27,11 @@ class I18nObject(BaseModel):
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
return self
|
||||
|
||||
def to_dict(self):
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
def to_dict(self) -> I18nObjectDict:
|
||||
result: I18nObjectDict = {
|
||||
"zh_Hans": self.zh_Hans,
|
||||
"en_US": self.en_US,
|
||||
"pt_BR": self.pt_BR,
|
||||
"ja_JP": self.ja_JP,
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -6,9 +6,20 @@ from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
ValidationInfo,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
MCPServerParameterType,
|
||||
PluginParameter,
|
||||
@@ -18,11 +29,19 @@ from core.plugin.entities.parameters import (
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class EmojiIconDict(TypedDict):
|
||||
background: str
|
||||
content: str
|
||||
|
||||
|
||||
emoji_icon_adapter: TypeAdapter[EmojiIconDict] = TypeAdapter(EmojiIconDict)
|
||||
|
||||
|
||||
class ToolLabelEnum(StrEnum):
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
@@ -410,15 +429,6 @@ class ToolEntity(BaseModel):
|
||||
return value or {}
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
|
||||
)
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: str | None = None
|
||||
|
||||
@@ -5,16 +5,19 @@ import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from graphon.runtime import VariablePool
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import TypedDict
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@@ -27,7 +30,6 @@ from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -49,9 +51,11 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
EmojiIconDict,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
emoji_icon_adapter,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
@@ -72,9 +76,7 @@ class ApiProviderControllerItem(TypedDict):
|
||||
controller: ApiToolProviderController
|
||||
|
||||
|
||||
class EmojiIconDict(TypedDict):
|
||||
background: str
|
||||
content: str
|
||||
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||
|
||||
|
||||
class WorkflowToolRuntimeSpec(Protocol):
|
||||
@@ -203,16 +205,160 @@ class ToolManager:
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
|
||||
builtin_tool = provider_controller.get_tool(tool_name)
|
||||
if not builtin_tool:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
builtin_tool = provider_controller.get_tool(tool_name)
|
||||
if not builtin_tool:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
if is_valid_uuid(credential_id):
|
||||
try:
|
||||
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
||||
except Exception as e:
|
||||
builtin_provider = None
|
||||
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||
|
||||
if builtin_provider is None:
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=builtin_provider.id,
|
||||
provider=provider_id,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
check_existence=False,
|
||||
)
|
||||
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(
|
||||
encrypter.encrypt(refreshed_credentials.credentials)
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
cache.delete()
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=builtin_provider.credential_type,
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
|
||||
case ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=api_provider,
|
||||
)
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(encrypter.decrypt(credentials)),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
case ToolProviderType.WORKFLOW:
|
||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
@@ -221,177 +367,28 @@ class ToolManager:
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
# get specific credentials
|
||||
if is_valid_uuid(credential_id):
|
||||
try:
|
||||
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
||||
except Exception as e:
|
||||
builtin_provider = None
|
||||
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
||||
# if the provider has been deleted, raise an error
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||
|
||||
# fallback to the default provider
|
||||
if builtin_provider is None:
|
||||
# use the default provider
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
# check if the credential is allowed to be used
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=builtin_provider.id,
|
||||
provider=provider_id,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
check_existence=False,
|
||||
)
|
||||
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
# decrypt the credentials
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||
|
||||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# refresh the credentials
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(
|
||||
encrypter.encrypt(refreshed_credentials.credentials)
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
cache.delete()
|
||||
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=builtin_provider.credential_type,
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=api_provider,
|
||||
)
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(encrypter.decrypt(credentials)),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
elif provider_type == ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
runtime = getattr(plugin_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return plugin_tool
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
runtime = getattr(mcp_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return mcp_tool
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
case ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
case ToolProviderType.PLUGIN:
|
||||
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
runtime = getattr(plugin_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return plugin_tool
|
||||
case ToolProviderType.MCP:
|
||||
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
runtime = getattr(mcp_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return mcp_tool
|
||||
case ToolProviderType.DATASET_RETRIEVAL:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
case _:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(
|
||||
@@ -885,7 +882,7 @@ class ToolManager:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
try:
|
||||
credentials = json.loads(provider_obj.credentials_str) or {}
|
||||
credentials = _credentials_adapter.validate_json(provider_obj.credentials_str) or {}
|
||||
except Exception:
|
||||
credentials = {}
|
||||
|
||||
@@ -910,7 +907,7 @@ class ToolManager:
|
||||
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
|
||||
|
||||
try:
|
||||
icon = json.loads(provider_obj.icon)
|
||||
icon = emoji_icon_adapter.validate_json(provider_obj.icon)
|
||||
except Exception:
|
||||
icon = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@@ -973,7 +970,7 @@ class ToolManager:
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
icon = json.loads(workflow_provider.icon)
|
||||
icon = emoji_icon_adapter.validate_json(workflow_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
@@ -990,7 +987,7 @@ class ToolManager:
|
||||
if api_provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
|
||||
icon = json.loads(api_provider.icon)
|
||||
icon = emoji_icon_adapter.validate_json(api_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
@@ -1025,31 +1022,31 @@ class ToolManager:
|
||||
:param provider_id: the id of the provider
|
||||
:return:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||
case ToolProviderType.API:
|
||||
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.WORKFLOW:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||
elif provider_type == ToolProviderType.API:
|
||||
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
case ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
case _:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_parameters_type(
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy import select
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
|
||||
@@ -6,8 +6,7 @@ from sqlalchemy import select
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
||||
@@ -305,14 +305,15 @@ class WorkflowTool(Tool):
|
||||
"transfer_method": file.transfer_method.value,
|
||||
"type": file.type.value,
|
||||
}
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
|
||||
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
|
||||
elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE:
|
||||
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
|
||||
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file.generate_url()
|
||||
match file.transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
|
||||
case FileTransferMethod.DATASOURCE_FILE:
|
||||
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file.generate_url()
|
||||
|
||||
files.append(file_dict)
|
||||
except Exception:
|
||||
@@ -357,8 +358,11 @@ class WorkflowTool(Tool):
|
||||
def _update_file_mapping(self, file_dict: dict):
|
||||
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
if transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_id
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file_id
|
||||
match transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_id
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file_id
|
||||
case FileTransferMethod.REMOTE_URL | FileTransferMethod.DATASOURCE_FILE:
|
||||
pass
|
||||
return file_dict
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Union
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameterAutoGenerate,
|
||||
PluginParameterOption,
|
||||
@@ -108,13 +109,6 @@ class EventEntity(BaseModel):
|
||||
return v or []
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionConstructor(BaseModel):
|
||||
"""
|
||||
The subscription constructor of the trigger provider
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import Literal, Union
|
||||
from typing import Union
|
||||
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.enums import NodeType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.entities.retrieval_settings import WeightedScoreConfig
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
@@ -18,50 +19,6 @@ class RerankingModelConfig(BaseModel):
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
|
||||
class EmbeddingSetting(BaseModel):
|
||||
"""
|
||||
Embedding Setting.
|
||||
"""
|
||||
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class EconomySetting(BaseModel):
|
||||
"""
|
||||
Economy Setting.
|
||||
"""
|
||||
|
||||
keyword_number: int
|
||||
|
||||
|
||||
class RetrievalSetting(BaseModel):
|
||||
"""
|
||||
Retrieval Setting.
|
||||
@@ -77,16 +34,6 @@ class RetrievalSetting(BaseModel):
|
||||
weights: WeightedScoreConfig | None = None
|
||||
|
||||
|
||||
class IndexMethod(BaseModel):
|
||||
"""
|
||||
Knowledge Index Setting.
|
||||
"""
|
||||
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
embedding_setting: EmbeddingSetting
|
||||
economy_setting: EconomySetting
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
"""
|
||||
File Info.
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class IndexingResultDict(TypedDict):
|
||||
dataset_id: str
|
||||
dataset_name: str
|
||||
batch: Any
|
||||
document_id: str
|
||||
document_name: str
|
||||
created_at: float
|
||||
display_status: str
|
||||
|
||||
|
||||
class PreviewItem(BaseModel):
|
||||
content: str | None = Field(default=None)
|
||||
child_chunks: list[str] | None = Field(default=None)
|
||||
@@ -34,7 +44,7 @@ class IndexProcessorProtocol(Protocol):
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
) -> dict[str, Any]: ...
|
||||
) -> IndexingResultDict: ...
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
@@ -6,6 +5,10 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities import Condition, MetadataFilteringCondition, WeightedScoreConfig
|
||||
|
||||
__all__ = ["Condition"]
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
@@ -16,33 +19,6 @@ class RerankingModelConfig(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
|
||||
class MultipleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
Multiple Retrieval Config.
|
||||
@@ -64,50 +40,6 @@ class SingleRetrievalConfig(BaseModel):
|
||||
model: ModelConfig
|
||||
|
||||
|
||||
SupportedComparisonOperator = Literal[
|
||||
# for string or array
|
||||
"contains",
|
||||
"not contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
# for time
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition detail
|
||||
"""
|
||||
|
||||
name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDictAdapter
|
||||
@@ -107,6 +107,26 @@ class _WorkflowChildEngineBuilder:
|
||||
return child_engine
|
||||
|
||||
|
||||
class _NodeConfigDict(TypedDict):
|
||||
id: str
|
||||
width: int
|
||||
height: int
|
||||
type: str
|
||||
data: dict[str, Any]
|
||||
|
||||
|
||||
class _EdgeConfigDict(TypedDict):
|
||||
source: str
|
||||
target: str
|
||||
sourceHandle: str
|
||||
targetHandle: str
|
||||
|
||||
|
||||
class SingleNodeGraphDict(TypedDict):
|
||||
nodes: list[_NodeConfigDict]
|
||||
edges: list[_EdgeConfigDict]
|
||||
|
||||
|
||||
class WorkflowEntry:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -318,7 +338,7 @@ class WorkflowEntry:
|
||||
node_data: dict[str, Any],
|
||||
node_width: int = 114,
|
||||
node_height: int = 514,
|
||||
) -> dict[str, Any]:
|
||||
) -> SingleNodeGraphDict:
|
||||
"""
|
||||
Create a minimal graph structure for testing a single node in isolation.
|
||||
|
||||
@@ -328,14 +348,14 @@ class WorkflowEntry:
|
||||
:param node_height: height for UI layout (default: 100)
|
||||
:return: graph dictionary with start node and target node
|
||||
"""
|
||||
node_config = {
|
||||
node_config: _NodeConfigDict = {
|
||||
"id": node_id,
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": node_data,
|
||||
}
|
||||
start_node_config = {
|
||||
start_node_config: _NodeConfigDict = {
|
||||
"id": "start",
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
@@ -346,9 +366,9 @@ class WorkflowEntry:
|
||||
"desc": "Start",
|
||||
},
|
||||
}
|
||||
return {
|
||||
"nodes": [start_node_config, node_config],
|
||||
"edges": [
|
||||
return SingleNodeGraphDict(
|
||||
nodes=[start_node_config, node_config],
|
||||
edges=[
|
||||
{
|
||||
"source": "start",
|
||||
"target": node_id,
|
||||
@@ -356,7 +376,7 @@ class WorkflowEntry:
|
||||
"targetHandle": "target",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def run_free_node(
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
|
||||
from events.app_event import app_published_workflow_was_updated
|
||||
@@ -45,7 +45,7 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
||||
Returns:
|
||||
Updated or created WorkflowSchedulePlan, or None if no schedule node
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
schedule_config = ScheduleService.extract_schedule_config(workflow)
|
||||
|
||||
existing_plan = session.scalar(
|
||||
@@ -59,7 +59,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
||||
if existing_plan:
|
||||
logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
|
||||
ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
|
||||
session.commit()
|
||||
return None
|
||||
|
||||
if existing_plan:
|
||||
@@ -73,7 +72,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
||||
schedule_id=existing_plan.id,
|
||||
updates=updates,
|
||||
)
|
||||
session.commit()
|
||||
return updated_plan
|
||||
else:
|
||||
new_plan = ScheduleService.create_schedule(
|
||||
@@ -82,5 +80,4 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
||||
app_id=app_id,
|
||||
config=schedule_config,
|
||||
)
|
||||
session.commit()
|
||||
return new_plan
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from events.app_event import app_published_workflow_was_updated
|
||||
@@ -31,7 +31,7 @@ def handle(sender, **kwargs):
|
||||
# Extract trigger info from workflow
|
||||
trigger_infos = get_trigger_infos_from_workflow(published_workflow)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Get existing app triggers
|
||||
existing_triggers = (
|
||||
session.execute(
|
||||
@@ -79,8 +79,6 @@ def handle(sender, **kwargs):
|
||||
existing_trigger.title = new_title
|
||||
session.add(existing_trigger)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
|
||||
"""
|
||||
|
||||
@@ -354,11 +354,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
) -> WorkflowRun | None:
|
||||
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
|
||||
)
|
||||
@@ -439,11 +439,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
|
||||
"""Fallback to PostgreSQL query for records not in LogStore."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
|
||||
return session.scalar(stmt)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user