Merge branch 'main' into copilot/fix-fc039c44-aa24-4844-82b2-e9ac237a13e9

This commit is contained in:
Asuka Minato
2025-10-11 03:01:39 +09:00
committed by GitHub
530 changed files with 9769 additions and 5466 deletions

View File

@@ -1,4 +1,4 @@
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev && apt-get -y install libgmp-dev libmpfr-dev libmpc-dev

View File

@@ -30,6 +30,8 @@ jobs:
run: | run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types) # Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF' cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union id: convert-optional-to-union

View File

@@ -18,7 +18,7 @@ jobs:
- name: Deploy to server - name: Deploy to server
uses: appleboy/ssh-action@v0.1.8 uses: appleboy/ssh-action@v0.1.8
with: with:
host: ${{ secrets.RAG_SSH_HOST }} host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }} username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }} key: ${{ secrets.SSH_PRIVATE_KEY }}
script: | script: |

View File

@@ -26,7 +26,6 @@ prepare-web:
@echo "🌐 Setting up web environment..." @echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists" @cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install @cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)" @echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment # Step 3: Prepare API environment

View File

@@ -40,18 +40,18 @@
<p align="center"> <p align="center">
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a> <a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README/README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a> <a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./README/README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a> <a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README/README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a> <a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README/README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a> <a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README/README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a> <a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README/README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a> <a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README/README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a> <a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README/README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a> <a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./README/README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a> <a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./README/README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a> <a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./README/README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a> <a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./README/README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a> <a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
</p> </p>
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production. Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.

View File

@@ -427,8 +427,8 @@ CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0 CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_MAX_NUMBER=9223372036854775807 CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808 CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000 CODE_MAX_STRING_LENGTH=400000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000 TEMPLATE_TRANSFORM_MAX_LENGTH=400000
CODE_MAX_STRING_ARRAY_LENGTH=30 CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30 CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000 CODE_MAX_NUMBER_ARRAY_LENGTH=1000

View File

@@ -81,7 +81,6 @@ ignore = [
"SIM113", # enumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
] ]
[lint.per-file-ignores] [lint.per-file-ignores]

View File

@@ -80,10 +80,10 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
``` ```
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
```bash ```bash
uv run celery -A app.celery beat uv run celery -A app.celery beat

View File

@@ -150,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_MAX_STRING_LENGTH: PositiveInt = Field( CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution", description="Maximum allowed length for strings in code execution",
default=80000, default=400_000,
) )
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
@@ -362,11 +362,11 @@ class HttpConfig(BaseSettings):
) )
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field( HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60 ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
) )
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field( HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20 ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
) )
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
@@ -582,6 +582,11 @@ class WorkflowConfig(BaseSettings):
default=200 * 1024, default=200 * 1024,
) )
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
description="Maximum number of characters allowed in Template Transform node output",
default=400_000,
)
# GraphEngine Worker Pool Configuration # GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field( GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance", description="Minimum number of workers per GraphEngine instance",
@@ -766,7 +771,7 @@ class MailConfig(BaseSettings):
MAIL_TEMPLATING_TIMEOUT: int = Field( MAIL_TEMPLATING_TIMEOUT: int = Field(
description=""" description="""
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates. Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Only available in sandbox mode.""", Only available in sandbox mode.""",
default=3, default=3,
) )

View File

@@ -1,4 +1,5 @@
from configs import dify_config from configs import dify_config
from libs.collection_utils import convert_to_lower_and_upper_set
HIDDEN_VALUE = "[__HIDDEN__]" HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]" UNKNOWN_VALUE = "[__UNKNOWN__]"
@@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3 DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"] VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
_doc_extensions: set[str]
_doc_extensions: list[str]
if dify_config.ETL_TYPE == "Unstructured": if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] _doc_extensions = {
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) "txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"vtt",
"properties",
"doc",
"docx",
"csv",
"eml",
"msg",
"pptx",
"xml",
"epub",
}
if dify_config.UNSTRUCTURED_API_URL: if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.append("ppt") _doc_extensions.add("ppt")
else: else:
_doc_extensions = [ _doc_extensions = {
"txt", "txt",
"markdown", "markdown",
"md", "md",
@@ -37,5 +53,5 @@ else:
"csv", "csv",
"vtt", "vtt",
"properties", "properties",
] }
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)

View File

@@ -1,31 +1,10 @@
from importlib import import_module
from flask import Blueprint from flask import Blueprint
from flask_restx import Namespace from flask_restx import Namespace
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import (
ConversationApi,
ConversationListApi,
ConversationPinApi,
ConversationRenameApi,
ConversationUnPinApi,
)
from .explore.message import (
MessageFeedbackApi,
MessageListApi,
MessageMoreLikeThisApi,
MessageSuggestedQuestionApi,
)
from .explore.workflow import (
InstalledAppWorkflowRunApi,
InstalledAppWorkflowTaskStopApi,
)
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("console", __name__, url_prefix="/console/api") bp = Blueprint("console", __name__, url_prefix="/console/api")
api = ExternalApi( api = ExternalApi(
@@ -35,23 +14,23 @@ api = ExternalApi(
description="Console management APIs for app configuration, monitoring, and administration", description="Console management APIs for app configuration, monitoring, and administration",
) )
# Create namespace
console_ns = Namespace("console", description="Console management API operations", path="/") console_ns = Namespace("console", description="Console management API operations", path="/")
# File RESOURCE_MODULES = (
api.add_resource(FileApi, "/files/upload") "controllers.console.app.app_import",
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") "controllers.console.explore.audio",
api.add_resource(FileSupportTypeApi, "/files/support-type") "controllers.console.explore.completion",
"controllers.console.explore.conversation",
"controllers.console.explore.message",
"controllers.console.explore.workflow",
"controllers.console.files",
"controllers.console.remote_files",
)
# Remote files for module_name in RESOURCE_MODULES:
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") import_module(module_name)
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Ensure resource modules are imported so route decorators are evaluated.
# Import other controllers # Import other controllers
from . import ( from . import (
admin, admin,
@@ -150,77 +129,6 @@ from .workspace import (
workspace, workspace,
) )
# Explore Audio
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
# Explore Completion
api.add_resource(
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
)
api.add_resource(
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)
# Explore Conversation
api.add_resource(
ConversationRenameApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)
# Explore Message
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
api.add_resource(
MessageFeedbackApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)
# Explore Workflow
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)
api.add_namespace(console_ns) api.add_namespace(console_ns)
__all__ = [ __all__ = [

View File

@@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required from libs.login import login_required
from libs.validators import validate_description_length
from models import Account, App from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
@@ -28,12 +29,6 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/apps") @console_ns.route("/apps")
class AppListApi(Resource): class AppListApi(Resource):
@api.doc("list_apps") @api.doc("list_apps")
@@ -138,7 +133,7 @@ class AppListApi(Resource):
"""Create app""" """Create app"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json") parser.add_argument("icon", type=str, location="json")
@@ -219,7 +214,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json") parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") parser.add_argument("icon_background", type=str, location="json")
@@ -297,7 +292,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json") parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json") parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") parser.add_argument("icon_background", type=str, location="json")

View File

@@ -20,7 +20,10 @@ from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from .. import console_ns
@console_ns.route("/apps/imports")
class AppImportApi(Resource): class AppImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -74,6 +77,7 @@ class AppImportApi(Resource):
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:import_id>/confirm")
class AppImportConfirmApi(Resource): class AppImportConfirmApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -98,6 +102,7 @@ class AppImportConfirmApi(Resource):
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
class AppImportCheckDependenciesApi(Resource): class AppImportCheckDependenciesApi(Resource):
@setup_required @setup_required
@login_required @login_required

View File

@@ -90,7 +90,7 @@ class ModelConfigResource(Resource):
if not isinstance(tool, dict) or len(tool.keys()) <= 3: if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue continue
agent_tool_entity = AgentToolEntity(**tool) agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool # get tool
try: try:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
@@ -124,7 +124,7 @@ class ModelConfigResource(Resource):
# encrypt agent tool parameters if it's secret-input # encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []: for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool) agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool # get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"

View File

@@ -2,7 +2,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
@@ -10,6 +10,7 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@console_ns.route("/api-key-auth/data-source")
class ApiKeyAuthDataSource(Resource): class ApiKeyAuthDataSource(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -33,6 +34,7 @@ class ApiKeyAuthDataSource(Resource):
return {"sources": []} return {"sources": []}
@console_ns.route("/api-key-auth/data-source/binding")
class ApiKeyAuthDataSourceBinding(Resource): class ApiKeyAuthDataSourceBinding(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -54,6 +56,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/api-key-auth/data-source/<uuid:binding_id>")
class ApiKeyAuthDataSourceBindingDelete(Resource): class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -66,8 +69,3 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
return {"result": "success"}, 204 return {"result": "success"}, 204
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")

View File

@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailAlreadyInUseError, EmailAlreadyInUseError,
EmailCodeError, EmailCodeError,
@@ -25,6 +25,7 @@ from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
@console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource): class EmailRegisterSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@@ -52,6 +53,7 @@ class EmailRegisterSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/email-register/validity")
class EmailRegisterCheckApi(Resource): class EmailRegisterCheckApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@@ -92,6 +94,7 @@ class EmailRegisterCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/email-register")
class EmailRegisterResetApi(Resource): class EmailRegisterResetApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@@ -148,8 +151,3 @@ class EmailRegisterResetApi(Resource):
raise AccountInFreezeError() raise AccountInFreezeError()
return account return account
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
api.add_resource(EmailRegisterResetApi, "/email-register")

View File

@@ -221,8 +221,3 @@ class ForgotPasswordResetApi(Resource):
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@@ -7,7 +7,7 @@ from flask_restx import Resource, reqparse
import services import services
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
AuthenticationFailedError, AuthenticationFailedError,
EmailCodeError, EmailCodeError,
@@ -34,6 +34,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
from services.feature_service import FeatureService from services.feature_service import FeatureService
@console_ns.route("/login")
class LoginApi(Resource): class LoginApi(Resource):
"""Resource for user login.""" """Resource for user login."""
@@ -91,6 +92,7 @@ class LoginApi(Resource):
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}
@console_ns.route("/logout")
class LogoutApi(Resource): class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
@@ -102,6 +104,7 @@ class LogoutApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/reset-password")
class ResetPasswordSendEmailApi(Resource): class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@@ -130,6 +133,7 @@ class ResetPasswordSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginSendEmailApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
@@ -162,6 +166,7 @@ class EmailCodeLoginSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource): class EmailCodeLoginApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
@@ -218,6 +223,7 @@ class EmailCodeLoginApi(Resource):
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}
@console_ns.route("/refresh-token")
class RefreshTokenApi(Resource): class RefreshTokenApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@@ -229,11 +235,3 @@ class RefreshTokenApi(Resource):
return {"result": "success", "data": new_token_pair.model_dump()} return {"result": "success", "data": new_token_pair.model_dump()}
except Exception as e: except Exception as e:
return {"result": "fail", "data": str(e)}, 401 return {"result": "fail", "data": str(e)}, 401
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
api.add_resource(RefreshTokenApi, "/refresh-token")

View File

@@ -14,7 +14,7 @@ from models.account import Account
from models.model import OAuthProviderApp from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
from .. import api from .. import console_ns
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
@@ -86,6 +86,7 @@ def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProvid
return decorated return decorated
@console_ns.route("/oauth/provider")
class OAuthServerAppApi(Resource): class OAuthServerAppApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
@@ -108,6 +109,7 @@ class OAuthServerAppApi(Resource):
) )
@console_ns.route("/oauth/provider/authorize")
class OAuthServerUserAuthorizeApi(Resource): class OAuthServerUserAuthorizeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -125,6 +127,7 @@ class OAuthServerUserAuthorizeApi(Resource):
) )
@console_ns.route("/oauth/provider/token")
class OAuthServerUserTokenApi(Resource): class OAuthServerUserTokenApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
@@ -180,6 +183,7 @@ class OAuthServerUserTokenApi(Resource):
) )
@console_ns.route("/oauth/provider/account")
class OAuthServerUserAccountApi(Resource): class OAuthServerUserAccountApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
@@ -194,9 +198,3 @@ class OAuthServerUserAccountApi(Resource):
"timezone": account.timezone, "timezone": account.timezone,
} }
) )
api.add_resource(OAuthServerAppApi, "/oauth/provider")
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")

View File

@@ -1,12 +1,13 @@
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.model import Account from models.model import Account
from services.billing_service import BillingService from services.billing_service import BillingService
@console_ns.route("/billing/subscription")
class Subscription(Resource): class Subscription(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -26,6 +27,7 @@ class Subscription(Resource):
) )
@console_ns.route("/billing/invoices")
class Invoices(Resource): class Invoices(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -36,7 +38,3 @@ class Invoices(Resource):
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
api.add_resource(Subscription, "/billing/subscription")
api.add_resource(Invoices, "/billing/invoices")

View File

@@ -6,10 +6,11 @@ from libs.helper import extract_remote_ip
from libs.login import login_required from libs.login import login_required
from services.billing_service import BillingService from services.billing_service import BillingService
from .. import api from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required from ..wraps import account_initialization_required, only_edition_cloud, setup_required
@console_ns.route("/compliance/download")
class ComplianceApi(Resource): class ComplianceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -30,6 +31,3 @@ class ComplianceApi(Resource):
ip=ip_address, ip=ip_address,
device_info=device_info, device_info=device_info,
) )
api.add_resource(ComplianceApi, "/compliance/download")

View File

@@ -9,13 +9,13 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
@@ -27,6 +27,10 @@ from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task from tasks.document_indexing_sync_task import document_indexing_sync_task
@console_ns.route(
"/data-source/integrates",
"/data-source/integrates/<uuid:binding_id>/<string:action>",
)
class DataSourceApi(Resource): class DataSourceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -109,6 +113,7 @@ class DataSourceApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/notion/pre-import/pages")
class DataSourceNotionListApi(Resource): class DataSourceNotionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -196,6 +201,10 @@ class DataSourceNotionListApi(Resource):
return {"notion_info": {**workspace_info, "pages": pages}}, 200 return {"notion_info": {**workspace_info, "pages": pages}}, 200
@console_ns.route(
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource): class DataSourceNotionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -248,13 +257,15 @@ class DataSourceNotionApi(Resource):
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": credential_id, {
"notion_workspace_id": workspace_id, "credential_id": credential_id,
"notion_obj_id": page["page_id"], "notion_workspace_id": workspace_id,
"notion_page_type": page["type"], "notion_obj_id": page["page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": page["type"],
}, "tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@@ -269,6 +280,7 @@ class DataSourceNotionApi(Resource):
return response.model_dump(), 200 return response.model_dump(), 200
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
class DataSourceNotionDatasetSyncApi(Resource): class DataSourceNotionDatasetSyncApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -285,6 +297,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync")
class DataSourceNotionDocumentSyncApi(Resource): class DataSourceNotionDocumentSyncApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -301,16 +314,3 @@ class DataSourceNotionDocumentSyncApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
document_indexing_sync_task.delay(dataset_id_str, document_id_str) document_indexing_sync_task.delay(dataset_id_str, document_id_str)
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
api.add_resource(
DataSourceNotionApi,
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
api.add_resource(
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
)

View File

@@ -1,4 +1,5 @@
import flask_restx from typing import Any, cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@@ -23,31 +24,27 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import related_app_list from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import login_required from libs.login import login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.") raise ValueError("Name must be between 1 to 40 characters.")
return name return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/datasets") @console_ns.route("/datasets")
class DatasetListApi(Resource): class DatasetListApi(Resource):
@api.doc("get_datasets") @api.doc("get_datasets")
@@ -92,7 +89,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models: for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields) data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
for item in data: for item in data:
# convert embedding_model_provider to plugin standard format # convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
@@ -147,7 +144,7 @@ class DatasetListApi(Resource):
) )
parser.add_argument( parser.add_argument(
"description", "description",
type=_validate_description_length, type=validate_description_length,
nullable=True, nullable=True,
required=False, required=False,
default="", default="",
@@ -192,7 +189,7 @@ class DatasetListApi(Resource):
name=args["name"], name=args["name"],
description=args["description"], description=args["description"],
indexing_technique=args["indexing_technique"], indexing_technique=args["indexing_technique"],
account=current_user, account=cast(Account, current_user),
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"], provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=args["external_knowledge_api_id"],
@@ -224,7 +221,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider: if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider) provider_id = ModelProviderID(dataset.embedding_model_provider)
@@ -288,7 +285,7 @@ class DatasetApi(Resource):
help="type is required. Name must be between 1 to 40 characters.", help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name, type=_validate_name,
) )
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
parser.add_argument( parser.add_argument(
"indexing_technique", "indexing_technique",
type=str, type=str,
@@ -369,7 +366,7 @@ class DatasetApi(Resource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields) result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if data.get("partial_member_list") and data.get("permission") == "partial_members":
@@ -516,13 +513,15 @@ class DatasetIndexingEstimateApi(Resource):
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": credential_id, {
"notion_workspace_id": workspace_id, "credential_id": credential_id,
"notion_obj_id": page["page_id"], "notion_workspace_id": workspace_id,
"notion_page_type": page["type"], "notion_obj_id": page["page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": page["type"],
}, "tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@@ -531,14 +530,16 @@ class DatasetIndexingEstimateApi(Resource):
for url in website_info_list["urls"]: for url in website_info_list["urls"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info=WebsiteInfo.model_validate(
"provider": website_info_list["provider"], {
"job_id": website_info_list["job_id"], "provider": website_info_list["provider"],
"url": url, "job_id": website_info_list["job_id"],
"tenant_id": current_user.current_tenant_id, "url": url,
"mode": "crawl", "tenant_id": current_user.current_tenant_id,
"only_main_content": website_info_list["only_main_content"], "mode": "crawl",
}, "only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@@ -688,7 +689,7 @@ class DatasetApiKeyApi(Resource):
) )
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
flask_restx.abort( api.abort(
400, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded", code="max_keys_exceeded",
@@ -733,7 +734,7 @@ class DatasetApiDeleteApi(Resource):
) )
if key is None: if key is None:
flask_restx.abort(404, message="API key not found") api.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()

View File

@@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import ( from fields.document_fields import (
dataset_and_document_fields, dataset_and_document_fields,
@@ -55,6 +55,7 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DocumentPipelineExecutionLog from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@@ -304,7 +305,7 @@ class DatasetDocumentListApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json" "doc_language", type=str, default="English", required=False, nullable=False, location="json"
) )
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique: if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@@ -394,7 +395,7 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.") raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@@ -418,7 +419,9 @@ class DatasetInitApi(Resource):
try: try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id( dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user tenant_id=current_user.current_tenant_id,
knowledge_config=knowledge_config,
account=cast(Account, current_user),
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@@ -452,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule data_process_rule = document.dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict() data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
@@ -514,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not documents: if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict() data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = [] extract_settings = []
for document in documents: for document in documents:
if document.indexing_status in {"completed", "error"}: if document.indexing_status in {"completed", "error"}:
@@ -544,13 +547,15 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
continue continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": data_source_info["credential_id"], {
"notion_workspace_id": data_source_info["notion_workspace_id"], "credential_id": data_source_info["credential_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_page_type": data_source_info["type"], "notion_obj_id": data_source_info["notion_page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": data_source_info["type"],
}, "tenant_id": current_user.current_tenant_id,
}
),
document_model=document.doc_form, document_model=document.doc_form,
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@@ -559,14 +564,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
continue continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info=WebsiteInfo.model_validate(
"provider": data_source_info["provider"], {
"job_id": data_source_info["job_id"], "provider": data_source_info["provider"],
"url": data_source_info["url"], "job_id": data_source_info["job_id"],
"tenant_id": current_user.current_tenant_id, "url": data_source_info["url"],
"mode": data_source_info["mode"], "tenant_id": current_user.current_tenant_id,
"only_main_content": data_source_info["only_main_content"], "mode": data_source_info["mode"],
}, "only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form, document_model=document.doc_form,
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@@ -753,7 +760,7 @@ class DocumentApi(DocumentResource):
} }
else: else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict data_source_info = document.data_source_detail_dict
response = { response = {
"id": document.id, "id": document.id,
@@ -1073,7 +1080,9 @@ class DocumentRenameApi(DocumentResource):
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_operator_permission(current_user, dataset) if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@@ -1114,6 +1123,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log")
class DocumentPipelineExecutionLogApi(DocumentResource): class DocumentPipelineExecutionLogApi(DocumentResource):
@setup_required @setup_required
@login_required @login_required
@@ -1147,29 +1157,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data, "input_data": log.input_data,
"datasource_node_id": log.datasource_node_id, "datasource_node_id": log.datasource_node_id,
}, 200 }, 200
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DatasetInitApi, "/datasets/init")
api.add_resource(
DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
)
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
)
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
api.add_resource(
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
)

View File

@@ -7,7 +7,7 @@ from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import ( from controllers.console.datasets.error import (
ChildChunkDeleteIndexError, ChildChunkDeleteIndexError,
@@ -37,6 +37,7 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource): class DatasetDocumentSegmentListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -139,6 +140,7 @@ class DatasetDocumentSegmentListApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
class DatasetDocumentSegmentApi(Resource): class DatasetDocumentSegmentApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -193,6 +195,7 @@ class DatasetDocumentSegmentApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
class DatasetDocumentSegmentAddApi(Resource): class DatasetDocumentSegmentAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -244,6 +247,7 @@ class DatasetDocumentSegmentAddApi(Resource):
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetDocumentSegmentUpdateApi(Resource): class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -305,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required @setup_required
@@ -345,6 +349,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>",
)
class DatasetDocumentSegmentBatchImportApi(Resource): class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -384,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
# send batch add segments task # send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting") redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay( batch_create_segment_to_index_task.delay(
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id str(job_id),
upload_file_id,
dataset_id,
document_id,
current_user.current_tenant_id,
current_user.id,
) )
except Exception as e: except Exception as e:
return {"error": str(e)}, 500 return {"error": str(e)}, 500
@@ -393,7 +406,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, job_id): def get(self, job_id=None, dataset_id=None, document_id=None):
if job_id is None:
raise NotFound("The job does not exist.")
job_id = str(job_id) job_id = str(job_id)
indexing_cache_key = f"segment_batch_import_{job_id}" indexing_cache_key = f"segment_batch_import_{job_id}"
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
@@ -403,6 +418,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
return {"job_id": job_id, "job_status": cache_result.decode()}, 200 return {"job_id": job_id, "job_status": cache_result.decode()}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
class ChildChunkAddApi(Resource): class ChildChunkAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -457,7 +473,8 @@ class ChildChunkAddApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) content = args["content"]
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@@ -546,13 +563,17 @@ class ChildChunkAddApi(Resource):
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200 return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@console_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>"
)
class ChildChunkUpdateApi(Resource): class ChildChunkUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -660,33 +681,8 @@ class ChildChunkUpdateApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
child_chunk = SegmentService.update_child_chunk( content = args["content"]
args.get("content"), child_chunk, segment, document, dataset child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
DatasetDocumentSegmentUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
)
api.add_resource(
DatasetDocumentSegmentBatchImportApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>",
)
api.add_resource(
ChildChunkAddApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
api.add_resource(
ChildChunkUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)

View File

@@ -1,3 +1,5 @@
from typing import cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse from flask_restx import Resource, fields, marshal, reqparse
@@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required from libs.login import login_required
from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService from services.knowledge_service import ExternalDatasetTestService
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 100: if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.") raise ValueError("Name must be between 1 to 100 characters.")
return name return name
@@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
response = HitTestingService.external_retrieve( response = HitTestingService.external_retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args["query"],
account=current_user, account=cast(Account, current_user),
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"], metadata_filtering_conditions=args["metadata_filtering_conditions"],
) )

View File

@@ -1,10 +1,11 @@
import logging import logging
from typing import cast
from flask_login import current_user from flask_login import current_user
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service import services
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@@ -20,6 +21,7 @@ from core.errors.error import (
) )
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields from fields.hit_testing_fields import hit_testing_record_fields
from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
@@ -59,7 +61,7 @@ class DatasetsHitTestingBase:
response = HitTestingService.retrieve( response = HitTestingService.retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args["query"],
account=current_user, account=cast(Account, current_user),
retrieval_model=args["retrieval_model"], retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args["external_retrieval_model"],
limit=10, limit=10,

View File

@@ -4,7 +4,7 @@ from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required from libs.login import login_required
@@ -16,6 +16,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource): class DatasetMetadataCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -27,7 +28,7 @@ class DatasetMetadataCreateApi(Resource):
parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataArgs(**args) metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@@ -50,6 +51,7 @@ class DatasetMetadataCreateApi(Resource):
return MetadataService.get_dataset_metadatas(dataset), 200 return MetadataService.get_dataset_metadatas(dataset), 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataApi(Resource): class DatasetMetadataApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -60,6 +62,7 @@ class DatasetMetadataApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
name = args["name"]
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@@ -68,7 +71,7 @@ class DatasetMetadataApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
return metadata, 200 return metadata, 200
@setup_required @setup_required
@@ -87,6 +90,7 @@ class DatasetMetadataApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/datasets/metadata/built-in")
class DatasetMetadataBuiltInFieldApi(Resource): class DatasetMetadataBuiltInFieldApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -97,6 +101,7 @@ class DatasetMetadataBuiltInFieldApi(Resource):
return {"fields": built_in_fields}, 200 return {"fields": built_in_fields}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
class DatasetMetadataBuiltInFieldActionApi(Resource): class DatasetMetadataBuiltInFieldActionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -116,6 +121,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditApi(Resource): class DocumentMetadataEditApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -131,15 +137,8 @@ class DocumentMetadataEditApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataOperationData(**args) metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")

View File

@@ -1,16 +1,16 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
setup_required, setup_required,
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen from libs.helper import StrLen
from libs.login import login_required from libs.login import login_required
@@ -19,6 +19,7 @@ from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource): class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -68,6 +69,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
return response return response
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/callback")
class DatasourceOAuthCallback(Resource): class DatasourceOAuthCallback(Resource):
@setup_required @setup_required
def get(self, provider_id: str): def get(self, provider_id: str):
@@ -123,6 +125,7 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource): class DatasourceAuth(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -165,6 +168,7 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200 return {"result": datasources}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource): class DatasourceAuthDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -188,6 +192,7 @@ class DatasourceAuthDeleteApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource): class DatasourceAuthUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -213,6 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
return {"result": "success"}, 201 return {"result": "success"}, 201
@console_ns.route("/auth/plugin/datasource/list")
class DatasourceAuthListApi(Resource): class DatasourceAuthListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -225,6 +231,7 @@ class DatasourceAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@console_ns.route("/auth/plugin/datasource/default-list")
class DatasourceHardCodeAuthListApi(Resource): class DatasourceHardCodeAuthListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -237,6 +244,7 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource): class DatasourceAuthOauthCustomClient(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -271,6 +279,7 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -291,6 +300,7 @@ class DatasourceAuthDefaultApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource): class DatasourceUpdateProviderNameApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -311,52 +321,3 @@ class DatasourceUpdateProviderNameApi(Resource):
credential_id=args["credential_id"], credential_id=args["credential_id"],
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(
DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
)
api.add_resource(
DatasourceOAuthCallback,
"/oauth/plugin/<path:provider_id>/datasource/callback",
)
api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource/<path:provider_id>",
)
api.add_resource(
DatasourceAuthUpdateApi,
"/auth/plugin/datasource/<path:provider_id>/update",
)
api.add_resource(
DatasourceAuthDeleteApi,
"/auth/plugin/datasource/<path:provider_id>/delete",
)
api.add_resource(
DatasourceAuthListApi,
"/auth/plugin/datasource/list",
)
api.add_resource(
DatasourceHardCodeAuthListApi,
"/auth/plugin/datasource/default-list",
)
api.add_resource(
DatasourceAuthOauthCustomClient,
"/auth/plugin/datasource/<path:provider_id>/custom-client",
)
api.add_resource(
DatasourceAuthDefaultApi,
"/auth/plugin/datasource/<path:provider_id>/default",
)
api.add_resource(
DatasourceUpdateProviderNameApi,
"/auth/plugin/datasource/<path:provider_id>/update-name",
)

View File

@@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
) )
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required from libs.login import current_user, login_required
@@ -13,6 +13,7 @@ from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource): class DataSourceContentPreviewApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -49,9 +50,3 @@ class DataSourceContentPreviewApi(Resource):
credential_id=args.get("credential_id"), credential_id=args.get("credential_id"),
) )
return preview_content, 200 return preview_content, 200
api.add_resource(
DataSourceContentPreviewApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)

View File

@@ -4,7 +4,7 @@ from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
enterprise_license_required, enterprise_license_required,
@@ -20,18 +20,19 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.") raise ValueError("Name must be between 1 to 40 characters.")
return name return name
def _validate_description_length(description): def _validate_description_length(description: str) -> str:
if len(description) > 400: if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.") raise ValueError("Description cannot exceed 400 characters.")
return description return description
@console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource): class PipelineTemplateListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -45,6 +46,7 @@ class PipelineTemplateListApi(Resource):
return pipeline_templates, 200 return pipeline_templates, 200
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
class PipelineTemplateDetailApi(Resource): class PipelineTemplateDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -57,6 +59,7 @@ class PipelineTemplateDetailApi(Resource):
return pipeline_template, 200 return pipeline_template, 200
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource): class CustomizedPipelineTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -73,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource):
) )
parser.add_argument( parser.add_argument(
"description", "description",
type=str, type=_validate_description_length,
nullable=True, nullable=True,
required=False, required=False,
default="", default="",
@@ -85,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource):
nullable=True, nullable=True,
) )
args = parser.parse_args() args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args) pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200 return 200
@@ -112,6 +115,7 @@ class CustomizedPipelineTemplateApi(Resource):
return {"data": template.yaml_content}, 200 return {"data": template.yaml_content}, 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource): class PublishCustomizedPipelineTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -129,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource):
) )
parser.add_argument( parser.add_argument(
"description", "description",
type=str, type=_validate_description_length,
nullable=True, nullable=True,
required=False, required=False,
default="", default="",
@@ -144,21 +148,3 @@ class PublishCustomizedPipelineTemplateApi(Resource):
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
return {"result": "success"} return {"result": "success"}
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
)
api.add_resource(
PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
PublishCustomizedPipelineTemplateApi,
"/rag/pipelines/<string:pipeline_id>/customized/publish",
)

View File

@@ -1,10 +1,10 @@
from flask_login import current_user # type: ignore # type: ignore from flask_login import current_user
from flask_restx import Resource, marshal, reqparse # type: ignore from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.console import api from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@@ -20,18 +20,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name): @console_ns.route("/rag/pipeline/dataset")
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class CreateRagPipelineDatasetApi(Resource): class CreateRagPipelineDatasetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -84,6 +73,7 @@ class CreateRagPipelineDatasetApi(Resource):
return import_info, 201 return import_info, 201
@console_ns.route("/rag/pipeline/empty-dataset")
class CreateEmptyRagPipelineDatasetApi(Resource): class CreateEmptyRagPipelineDatasetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -108,7 +98,3 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
), ),
) )
return marshal(dataset, dataset_detail_fields), 201 return marshal(dataset, dataset_detail_fields), 201
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")

View File

@@ -1,24 +1,22 @@
import logging import logging
from typing import Any, NoReturn from typing import NoReturn
from flask import Response from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
DraftWorkflowNotExist, DraftWorkflowNotExist,
) )
from controllers.console.app.workflow_draft_variable import ( from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS, _WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
) )
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db from extensions.ext_database import db
@@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser(): def _create_pagination_parser():
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
@@ -104,13 +76,14 @@ def _api_prerequisite(f):
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not isinstance(current_user, Account) or not current_user.is_editor: if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)
return wrapper return wrapper
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
class RagPipelineVariableCollectionApi(Resource): class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@@ -168,6 +141,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
return None return None
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
class RagPipelineNodeVariableCollectionApi(Resource): class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@@ -190,6 +164,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
return Response("", 204) return Response("", 204)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>")
class RagPipelineVariableApi(Resource): class RagPipelineVariableApi(Resource):
_PATCH_NAME_FIELD = "name" _PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value" _PATCH_VALUE_FIELD = "value"
@@ -284,6 +259,7 @@ class RagPipelineVariableApi(Resource):
return Response("", 204) return Response("", 204)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class RagPipelineVariableResetApi(Resource): class RagPipelineVariableResetApi(Resource):
@_api_prerequisite @_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: str): def put(self, pipeline: Pipeline, variable_id: str):
@@ -325,6 +301,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
return draft_vars return draft_vars
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
class RagPipelineSystemVariableCollectionApi(Resource): class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@@ -332,6 +309,7 @@ class RagPipelineSystemVariableCollectionApi(Resource):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID) return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
class RagPipelineEnvironmentVariableCollectionApi(Resource): class RagPipelineEnvironmentVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
@@ -364,26 +342,3 @@ class RagPipelineEnvironmentVariableCollectionApi(Resource):
) )
return {"items": env_vars_list} return {"items": env_vars_list}
api.add_resource(
RagPipelineVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
)
api.add_resource(
RagPipelineNodeVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
)
api.add_resource(
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
)
api.add_resource(
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
)
api.add_resource(
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
)
api.add_resource(
RagPipelineEnvironmentVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
)

View File

@@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@@ -20,6 +20,7 @@ from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource): class RagPipelineImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -66,6 +67,7 @@ class RagPipelineImportApi(Resource):
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
class RagPipelineImportConfirmApi(Resource): class RagPipelineImportConfirmApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -90,6 +92,7 @@ class RagPipelineImportConfirmApi(Resource):
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
class RagPipelineImportCheckDependenciesApi(Resource): class RagPipelineImportCheckDependenciesApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -107,6 +110,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
class RagPipelineExportApi(Resource): class RagPipelineExportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -128,22 +132,3 @@ class RagPipelineExportApi(Resource):
) )
return {"data": result}, 200 return {"data": result}, 200
# Import Rag Pipeline
api.add_resource(
RagPipelineImportApi,
"/rag/pipelines/imports",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
ConversationCompletedError, ConversationCompletedError,
DraftWorkflowNotExist, DraftWorkflowNotExist,
@@ -50,6 +50,7 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource): class DraftRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -147,6 +148,7 @@ class DraftRagPipelineApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource): class RagPipelineDraftRunIterationNodeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -181,6 +183,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
raise InternalServerError() raise InternalServerError()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource): class RagPipelineDraftRunLoopNodeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -215,6 +218,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError() raise InternalServerError()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource): class DraftRagPipelineRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -249,6 +253,7 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description) raise InvokeRateLimitHttpError(ex.description)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource): class PublishedRagPipelineRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -369,6 +374,7 @@ class PublishedRagPipelineRunApi(Resource):
# #
# return result # return result
# #
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource): class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -411,6 +417,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
) )
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource): class RagPipelineDraftDatasourceNodeRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -453,6 +460,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
) )
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource): class RagPipelineDraftNodeRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -486,6 +494,7 @@ class RagPipelineDraftNodeRunApi(Resource):
return workflow_node_execution return workflow_node_execution
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop")
class RagPipelineTaskStopApi(Resource): class RagPipelineTaskStopApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -504,6 +513,7 @@ class RagPipelineTaskStopApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/publish")
class PublishedRagPipelineApi(Resource): class PublishedRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -559,6 +569,7 @@ class PublishedRagPipelineApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs")
class DefaultRagPipelineBlockConfigsApi(Resource): class DefaultRagPipelineBlockConfigsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -577,6 +588,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs() return rag_pipeline_service.get_default_block_configs()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource): class DefaultRagPipelineBlockConfigApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -608,6 +620,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource): class PublishedAllRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -656,6 +669,7 @@ class PublishedAllRagPipelineApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource): class RagPipelineByIdApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -713,6 +727,7 @@ class RagPipelineByIdApi(Resource):
return workflow return workflow
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource): class PublishedRagPipelineSecondStepApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -738,6 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource): class PublishedRagPipelineFirstStepApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -763,6 +779,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource): class DraftRagPipelineFirstStepApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -788,6 +805,7 @@ class DraftRagPipelineFirstStepApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource): class DraftRagPipelineSecondStepApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -814,6 +832,7 @@ class DraftRagPipelineSecondStepApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource): class RagPipelineWorkflowRunListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -835,6 +854,7 @@ class RagPipelineWorkflowRunListApi(Resource):
return result return result
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>")
class RagPipelineWorkflowRunDetailApi(Resource): class RagPipelineWorkflowRunDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -853,6 +873,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
return workflow_run return workflow_run
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions")
class RagPipelineWorkflowRunNodeExecutionListApi(Resource): class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -876,6 +897,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
return {"data": node_executions} return {"data": node_executions}
@console_ns.route("/rag/pipelines/datasource-plugins")
class DatasourceListApi(Resource): class DatasourceListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -891,6 +913,7 @@ class DatasourceListApi(Resource):
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
class RagPipelineWorkflowLastRunApi(Resource): class RagPipelineWorkflowLastRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -912,6 +935,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
return node_exec return node_exec
@console_ns.route("/rag/pipelines/transform/datasets/<uuid:dataset_id>")
class RagPipelineTransformApi(Resource): class RagPipelineTransformApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -929,6 +953,7 @@ class RagPipelineTransformApi(Resource):
return result return result
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource): class RagPipelineDatasourceVariableApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -958,6 +983,7 @@ class RagPipelineDatasourceVariableApi(Resource):
return workflow_node_execution return workflow_node_execution
@console_ns.route("/rag/pipelines/recommended-plugins")
class RagPipelineRecommendedPluginApi(Resource): class RagPipelineRecommendedPluginApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -966,114 +992,3 @@ class RagPipelineRecommendedPluginApi(Resource):
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins() recommended_plugins = rag_pipeline_service.get_recommended_plugins()
return recommended_plugins return recommended_plugins
api.add_resource(
DraftRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
)
api.add_resource(
DraftRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
)
api.add_resource(
PublishedRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/run",
)
api.add_resource(
RagPipelineTaskStopApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
RagPipelineDraftNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelinePublishedDatasourceNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftDatasourceNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftRunIterationNodeApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftRunLoopNodeApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/publish",
)
api.add_resource(
PublishedAllRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows",
)
api.add_resource(
DefaultRagPipelineBlockConfigsApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultRagPipelineBlockConfigApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
RagPipelineByIdApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
)
api.add_resource(
RagPipelineWorkflowRunListApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs",
)
api.add_resource(
RagPipelineWorkflowRunDetailApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>",
)
api.add_resource(
RagPipelineWorkflowRunNodeExecutionListApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions",
)
api.add_resource(
DatasourceListApi,
"/rag/pipelines/datasource-plugins",
)
api.add_resource(
PublishedRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
)
api.add_resource(
PublishedRagPipelineFirstStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters",
)
api.add_resource(
DraftRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
)
api.add_resource(
DraftRagPipelineFirstStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters",
)
api.add_resource(
RagPipelineWorkflowLastRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run",
)
api.add_resource(
RagPipelineTransformApi,
"/rag/pipelines/transform/datasets/<uuid:dataset_id>",
)
api.add_resource(
RagPipelineDatasourceVariableApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect",
)
api.add_resource(
RagPipelineRecommendedPluginApi,
"/rag/pipelines/recommended-plugins",
)

View File

@@ -26,9 +26,15 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError, UnsupportedAudioTypeServiceError,
) )
from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
endpoint="installed_app_audio",
)
class ChatAudioApi(InstalledAppResource): class ChatAudioApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
@@ -65,6 +71,10 @@ class ChatAudioApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/text-to-audio",
endpoint="installed_app_text",
)
class ChatTextApi(InstalledAppResource): class ChatTextApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
from flask_restx import reqparse from flask_restx import reqparse

View File

@@ -33,10 +33,16 @@ from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# define completion api for user # define completion api for user
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/completion-messages",
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource): class CompletionApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
@@ -87,6 +93,10 @@ class CompletionApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
class CompletionStopApi(InstalledAppResource): class CompletionStopApi(InstalledAppResource):
def post(self, installed_app, task_id): def post(self, installed_app, task_id):
app_model = installed_app.app app_model = installed_app.app
@@ -100,6 +110,10 @@ class CompletionStopApi(InstalledAppResource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/chat-messages",
endpoint="installed_app_chat_completion",
)
class ChatApi(InstalledAppResource): class ChatApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
@@ -153,6 +167,10 @@ class ChatApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)
class ChatStopApi(InstalledAppResource): class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id): def post(self, installed_app, task_id):
app_model = installed_app.app app_model = installed_app.app

View File

@@ -16,7 +16,13 @@ from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
from .. import console_ns
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations",
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource): class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
@@ -52,6 +58,10 @@ class ConversationListApi(InstalledAppResource):
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
class ConversationApi(InstalledAppResource): class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id): def delete(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
@@ -70,6 +80,10 @@ class ConversationApi(InstalledAppResource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
class ConversationRenameApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
@@ -95,6 +109,10 @@ class ConversationRenameApi(InstalledAppResource):
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
class ConversationPinApi(InstalledAppResource): class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
@@ -114,6 +132,10 @@ class ConversationPinApi(InstalledAppResource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)
class ConversationUnPinApi(InstalledAppResource): class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app

View File

@@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_, select from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db from extensions.ext_database import db
@@ -22,6 +22,7 @@ from services.feature_service import FeatureService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource): class InstalledAppsListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource):
return {"message": "App installed successfully"} return {"message": "App installed successfully"}
@console_ns.route("/installed-apps/<uuid:installed_app_id>")
class InstalledAppApi(InstalledAppResource): class InstalledAppApi(InstalledAppResource):
""" """
update and delete an installed app update and delete an installed app
@@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource):
db.session.commit() db.session.commit()
return {"result": "success", "message": "App info updated successfully"} return {"result": "success", "message": "App info updated successfully"}
api.add_resource(InstalledAppsListApi, "/installed-apps")
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")

View File

@@ -36,9 +36,15 @@ from services.errors.message import (
) )
from services.message_service import MessageService from services.message_service import MessageService
from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages",
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
@@ -66,6 +72,10 @@ class MessageListApi(InstalledAppResource):
raise NotFound("First Message Not Exists.") raise NotFound("First Message Not Exists.")
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
class MessageFeedbackApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id): def post(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
@@ -93,6 +103,10 @@ class MessageFeedbackApi(InstalledAppResource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
class MessageMoreLikeThisApi(InstalledAppResource): class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
@@ -139,6 +153,10 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)
class MessageSuggestedQuestionApi(InstalledAppResource): class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app

View File

@@ -1,7 +1,7 @@
from flask_restx import marshal_with from flask_restx import marshal_with
from controllers.common import fields from controllers.common import fields
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
from services.app_service import AppService from services.app_service import AppService
@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
class AppParameterApi(InstalledAppResource): class AppParameterApi(InstalledAppResource):
"""Resource for app variables.""" """Resource for app variables."""
@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
class ExploreAppMetaApi(InstalledAppResource): class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Get app meta""" """Get app meta"""
@@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
if not app_model: if not app_model:
raise ValueError("App not found") raise ValueError("App not found")
return AppService().get_app_meta(app_model) return AppService().get_app_meta(app_model)
api.add_resource(
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
)
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField from libs.helper import AppIconUrlField
from libs.login import current_user, login_required from libs.login import current_user, login_required
@@ -35,6 +35,7 @@ recommended_app_list_fields = {
} }
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource): class RecommendedAppListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -56,13 +57,10 @@ class RecommendedAppListApi(Resource):
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix) return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
@console_ns.route("/explore/apps/<uuid:app_id>")
class RecommendedAppApi(Resource): class RecommendedAppApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
app_id = str(app_id) app_id = str(app_id)
return RecommendedAppService.get_recommend_app_detail(app_id) return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, "/explore/apps")
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")

View File

@@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
@@ -25,6 +25,7 @@ message_fields = {
} }
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource): class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = { saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer, "limit": fields.Integer,
@@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
)
class SavedMessageApi(InstalledAppResource): class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id): def delete(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
@@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id) SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204 return {"result": "success"}, 204
api.add_resource(
SavedMessageListApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)

View File

@@ -27,9 +27,12 @@ from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource): class InstalledAppWorkflowRunApi(InstalledAppResource):
def post(self, installed_app: InstalledApp): def post(self, installed_app: InstalledApp):
""" """
@@ -70,6 +73,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop")
class InstalledAppWorkflowTaskStopApi(InstalledAppResource): class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
def post(self, installed_app: InstalledApp, task_id: str): def post(self, installed_app: InstalledApp, task_id: str):
""" """

View File

@@ -26,9 +26,12 @@ from libs.login import login_required
from models import Account from models import Account
from services.file_service import FileService from services.file_service import FileService
from . import console_ns
PREVIEW_WORDS_LIMIT = 3000 PREVIEW_WORDS_LIMIT = 3000
@console_ns.route("/files/upload")
class FileApi(Resource): class FileApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -88,6 +91,7 @@ class FileApi(Resource):
return upload_file, 201 return upload_file, 201
@console_ns.route("/files/<uuid:file_id>/preview")
class FilePreviewApi(Resource): class FilePreviewApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -98,6 +102,7 @@ class FilePreviewApi(Resource):
return {"content": text} return {"content": text}
@console_ns.route("/files/support-type")
class FileSupportTypeApi(Resource): class FileSupportTypeApi(Resource):
@setup_required @setup_required
@login_required @login_required

View File

@@ -19,7 +19,10 @@ from fields.file_fields import file_fields_with_signed_url, remote_file_info_fie
from models.account import Account from models.account import Account
from services.file_service import FileService from services.file_service import FileService
from . import console_ns
@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource): class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields) @marshal_with(remote_file_info_fields)
def get(self, url): def get(self, url):
@@ -35,6 +38,7 @@ class RemoteFileInfoApi(Resource):
} }
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource): class RemoteFileUploadApi(Resource):
@marshal_with(file_fields_with_signed_url) @marshal_with(file_fields_with_signed_url)
def post(self): def post(self):

View File

@@ -2,7 +2,6 @@ import logging
from flask_restx import Resource from flask_restx import Resource
from controllers.console import api
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
setup_required, setup_required,
@@ -10,9 +9,12 @@ from controllers.console.wraps import (
from core.schemas.schema_manager import SchemaManager from core.schemas.schema_manager import SchemaManager
from libs.login import login_required from libs.login import login_required
from . import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/spec/schema-definitions")
class SpecSchemaDefinitionsApi(Resource): class SpecSchemaDefinitionsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -30,6 +32,3 @@ class SpecSchemaDefinitionsApi(Resource):
logger.exception("Failed to get schema definitions from local registry") logger.exception("Failed to get schema definitions from local registry")
# Return empty array as fallback # Return empty array as fallback
return [], 200 return [], 200
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")

View File

@@ -3,7 +3,7 @@ from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields from fields.tag_fields import dataset_tag_fields
from libs.login import login_required from libs.login import login_required
@@ -17,6 +17,7 @@ def _validate_name(name):
return name return name
@console_ns.route("/tags")
class TagListApi(Resource): class TagListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -52,6 +53,7 @@ class TagListApi(Resource):
return response, 200 return response, 200
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource): class TagUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -89,6 +91,7 @@ class TagUpdateDeleteApi(Resource):
return 204 return 204
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource): class TagBindingCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -114,6 +117,7 @@ class TagBindingCreateApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource): class TagBindingDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -133,9 +137,3 @@ class TagBindingDeleteApi(Resource):
TagService.delete_tag_binding(args) TagService.delete_tag_binding(args)
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(TagListApi, "/tags")
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailAlreadyInUseError, EmailAlreadyInUseError,
EmailChangeLimitError, EmailChangeLimitError,
@@ -45,6 +45,7 @@ from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@console_ns.route("/account/init")
class AccountInitApi(Resource): class AccountInitApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -97,6 +98,7 @@ class AccountInitApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/account/profile")
class AccountProfileApi(Resource): class AccountProfileApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -109,6 +111,7 @@ class AccountProfileApi(Resource):
return current_user return current_user
@console_ns.route("/account/name")
class AccountNameApi(Resource): class AccountNameApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -130,6 +133,7 @@ class AccountNameApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource): class AccountAvatarApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -147,6 +151,7 @@ class AccountAvatarApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource): class AccountInterfaceLanguageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -164,6 +169,7 @@ class AccountInterfaceLanguageApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource): class AccountInterfaceThemeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -181,6 +187,7 @@ class AccountInterfaceThemeApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource): class AccountTimezoneApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -202,6 +209,7 @@ class AccountTimezoneApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/password")
class AccountPasswordApi(Resource): class AccountPasswordApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -227,6 +235,7 @@ class AccountPasswordApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource): class AccountIntegrateApi(Resource):
integrate_fields = { integrate_fields = {
"provider": fields.String, "provider": fields.String,
@@ -283,6 +292,7 @@ class AccountIntegrateApi(Resource):
return {"data": integrate_data} return {"data": integrate_data}
@console_ns.route("/account/delete/verify")
class AccountDeleteVerifyApi(Resource): class AccountDeleteVerifyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -298,6 +308,7 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource): class AccountDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -320,6 +331,7 @@ class AccountDeleteApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
@@ -333,6 +345,7 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource): class EducationVerifyApi(Resource):
verify_fields = { verify_fields = {
"token": fields.String, "token": fields.String,
@@ -352,6 +365,7 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email) return BillingService.EducationIdentity.verify(account.id, account.email)
@console_ns.route("/account/education")
class EducationApi(Resource): class EducationApi(Resource):
status_fields = { status_fields = {
"result": fields.Boolean, "result": fields.Boolean,
@@ -396,6 +410,7 @@ class EducationApi(Resource):
return res return res
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource): class EducationAutoCompleteApi(Resource):
data_fields = { data_fields = {
"data": fields.List(fields.String), "data": fields.List(fields.String),
@@ -419,6 +434,7 @@ class EducationAutoCompleteApi(Resource):
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource): class ChangeEmailSendEmailApi(Resource):
@enable_change_email @enable_change_email
@setup_required @setup_required
@@ -467,6 +483,7 @@ class ChangeEmailSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource): class ChangeEmailCheckApi(Resource):
@enable_change_email @enable_change_email
@setup_required @setup_required
@@ -508,6 +525,7 @@ class ChangeEmailCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource): class ChangeEmailResetApi(Resource):
@enable_change_email @enable_change_email
@setup_required @setup_required
@@ -547,6 +565,7 @@ class ChangeEmailResetApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource): class CheckEmailUnique(Resource):
@setup_required @setup_required
def post(self): def post(self):
@@ -558,28 +577,3 @@ class CheckEmailUnique(Resource):
if not AccountService.check_email_unique(args["email"]): if not AccountService.check_email_unique(args["email"]):
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
return {"result": "success"} return {"result": "success"}
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
api.add_resource(AccountNameApi, "/account/name")
api.add_resource(AccountAvatarApi, "/account/avatar")
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# Change email
api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@@ -1,7 +1,7 @@
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource): class LoadBalancingCredentialsValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
return response return response
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource): class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
response["error"] = error response["error"] = error
return response return response
# Load Balancing Config
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@@ -6,7 +6,7 @@ from flask_restx import Resource, marshal_with, reqparse
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
CannotTransferOwnerToSelfError, CannotTransferOwnerToSelfError,
EmailCodeError, EmailCodeError,
@@ -33,6 +33,7 @@ from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService from services.feature_service import FeatureService
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource): class MemberListApi(Resource):
"""List all members of current tenant.""" """List all members of current tenant."""
@@ -49,6 +50,7 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource): class MemberInviteEmailApi(Resource):
"""Invite a new member by email.""" """Invite a new member by email."""
@@ -111,6 +113,7 @@ class MemberInviteEmailApi(Resource):
}, 201 }, 201
@console_ns.route("/workspaces/current/members/<uuid:member_id>")
class MemberCancelInviteApi(Resource): class MemberCancelInviteApi(Resource):
"""Cancel an invitation by member id.""" """Cancel an invitation by member id."""
@@ -143,6 +146,7 @@ class MemberCancelInviteApi(Resource):
}, 200 }, 200
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource): class MemberUpdateRoleApi(Resource):
"""Update member role.""" """Update member role."""
@@ -177,6 +181,7 @@ class MemberUpdateRoleApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/dataset-operators")
class DatasetOperatorMemberListApi(Resource): class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant.""" """List all members of current tenant."""
@@ -193,6 +198,7 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource): class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email.""" """Send owner transfer email."""
@@ -233,6 +239,7 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource): class OwnerTransferCheckApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -278,6 +285,7 @@ class OwnerTransferCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource): class OwnerTransfer(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -339,14 +347,3 @@ class OwnerTransfer(Resource):
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
api.add_resource(MemberListApi, "/workspaces/current/members")
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
# owner transfer
api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer")

View File

@@ -5,7 +5,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -17,6 +17,7 @@ from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource): class ModelProviderListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -45,6 +46,7 @@ class ModelProviderListApi(Resource):
return jsonable_encoder({"data": provider_list}) return jsonable_encoder({"data": provider_list})
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource): class ModelProviderCredentialApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource): class ModelProviderCredentialSwitchApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource): class ModelProviderValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource):
return response return response
@console_ns.route("/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>")
class ModelProviderIconApi(Resource): class ModelProviderIconApi(Resource):
""" """
Get model provider icon Get model provider icon
@@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype) return send_file(io.BytesIO(icon), mimetype=mimetype)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource): class PreferredProviderTypeUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/checkout-url")
class ModelProviderPaymentCheckoutUrlApi(Resource): class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
prefilled_email=current_user.email, prefilled_email=current_user.email,
) )
return data return data
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
)
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
ModelProviderIconApi,
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)

View File

@@ -4,7 +4,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource): class DefaultModelApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -85,6 +86,7 @@ class DefaultModelApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource): class ModelProviderModelApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource): class ModelProviderModelCredentialApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -364,6 +367,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource): class ModelProviderModelCredentialSwitchApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource): class ModelProviderModelEnableApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource): class ModelProviderModelDisableApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource): class ModelProviderModelValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource):
return response return response
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource): class ModelProviderModelParameterRuleApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource):
return jsonable_encoder({"data": parameter_rules}) return jsonable_encoder({"data": parameter_rules})
@console_ns.route("/workspaces/current/models/model-types/<string:model_type>")
class ModelProviderAvailableModelApi(Resource): class ModelProviderAvailableModelApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource):
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models}) return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<path:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<path:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelCredentialSwitchApi,
"/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
)
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService from services.plugin.plugin_service import PluginService
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource): class PluginDebuggingKeyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource): class PluginListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -55,6 +57,7 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource): class PluginListLatestVersionsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource):
return jsonable_encoder({"versions": versions}) return jsonable_encoder({"versions": versions})
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource): class PluginListInstallationsFromIdsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource):
return jsonable_encoder({"plugins": plugins}) return jsonable_encoder({"plugins": plugins})
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource): class PluginIconApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
@@ -108,6 +113,7 @@ class PluginIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/plugin/upload/pkg")
class PluginUploadFromPkgApi(Resource): class PluginUploadFromPkgApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource): class PluginUploadFromGithubApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/bundle")
class PluginUploadFromBundleApi(Resource): class PluginUploadFromBundleApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource): class PluginInstallFromPkgApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource): class PluginInstallFromGithubApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource): class PluginInstallFromMarketplaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource): class PluginFetchMarketplacePkgApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource): class PluginFetchManifestApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource): class PluginFetchInstallTasksApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>")
class PluginFetchInstallTaskApi(Resource): class PluginFetchInstallTaskApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
class PluginDeleteInstallTaskApi(Resource): class PluginDeleteInstallTaskApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
class PluginDeleteAllInstallTaskItemsApi(Resource): class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
class PluginDeleteInstallTaskItemApi(Resource): class PluginDeleteInstallTaskItemApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource): class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource): class PluginUpgradeFromGithubApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource): class PluginUninstallApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -453,6 +474,7 @@ class PluginUninstallApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource): class PluginChangePermissionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource):
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@console_ns.route("/workspaces/current/plugin/permission/fetch")
class PluginFetchPermissionApi(Resource): class PluginFetchPermissionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource):
) )
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource): class PluginFetchDynamicSelectOptionsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options}) return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource): class PluginChangePreferencesApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource):
return jsonable_encoder({"success": True}) return jsonable_encoder({"success": True})
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
class PluginFetchPreferencesApi(Resource): class PluginFetchPreferencesApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict}) return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource): class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
args = req.parse_args() args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")

View File

@@ -10,7 +10,7 @@ from flask_restx import (
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
enterprise_license_required, enterprise_license_required,
@@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool:
return False return False
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource): class ToolProviderListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -71,6 +72,7 @@ class ToolProviderListApi(Resource):
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
class ToolBuiltinProviderListToolsApi(Resource): class ToolBuiltinProviderListToolsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/info")
class ToolBuiltinProviderInfoApi(Resource): class ToolBuiltinProviderInfoApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource): class ToolBuiltinProviderDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource): class ToolBuiltinProviderAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource): class ToolBuiltinProviderUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
return result return result
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credentials")
class ToolBuiltinProviderGetCredentialsApi(Resource): class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource): class ToolBuiltinProviderIconApi(Resource):
@setup_required @setup_required
def get(self, provider): def get(self, provider):
@@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource): class ToolApiProviderAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource): class ToolApiProviderGetRemoteSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource): class ToolApiProviderListToolsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource): class ToolApiProviderUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource): class ToolApiProviderDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource): class ToolApiProviderGetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>")
class ToolBuiltinProviderCredentialsSchemaApi(Resource): class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource): class ToolApiProviderSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource): class ToolApiProviderPreviousTestApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource): class ToolWorkflowProviderCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource): class ToolWorkflowProviderUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource): class ToolWorkflowProviderDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource): class ToolWorkflowProviderGetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource):
return jsonable_encoder(tool) return jsonable_encoder(tool)
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource): class ToolWorkflowProviderListToolApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource):
) )
@console_ns.route("/workspaces/current/tools/builtin")
class ToolBuiltinListApi(Resource): class ToolBuiltinListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -624,6 +647,7 @@ class ToolBuiltinListApi(Resource):
) )
@console_ns.route("/workspaces/current/tools/api")
class ToolApiListApi(Resource): class ToolApiListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -642,6 +666,7 @@ class ToolApiListApi(Resource):
) )
@console_ns.route("/workspaces/current/tools/workflow")
class ToolWorkflowListApi(Resource): class ToolWorkflowListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-labels")
class ToolLabelsApi(Resource): class ToolLabelsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -672,6 +698,7 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels()) return jsonable_encoder(ToolLabelsService.list_tool_labels())
@console_ns.route("/oauth/plugin/<path:provider>/tool/authorization-url")
class ToolPluginOAuthApi(Resource): class ToolPluginOAuthApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource):
return response return response
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource): class ToolOAuthCallback(Resource):
@setup_required @setup_required
def get(self, provider): def get(self, provider):
@@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource): class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource): class ToolOAuthCustomClient(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema")
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/info")
class ToolBuiltinProviderGetCredentialInfoApi(Resource): class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource): class ToolProviderMCPApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource): class ToolMCPAuthApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource):
raise ValueError(f"Failed to connect to MCP server: {e}") from e raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
class ToolMCPDetailApi(Resource): class ToolMCPDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource):
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
class ToolMCPListAllApi(Resource): class ToolMCPListAllApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource):
return [tool.to_dict() for tool in tools] return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
class ToolMCPUpdateApi(Resource): class ToolMCPUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -1014,6 +1051,7 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools) return jsonable_encoder(tools)
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource): class ToolMCPCallbackApi(Resource):
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource):
authorization_code = args["code"] authorization_code = args["code"]
handle_callback(state_key, authorization_code) handle_callback(state_key, authorization_code)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
# api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
# mcp tool provider
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

View File

@@ -14,7 +14,7 @@ from controllers.common.errors import (
TooManyFilesError, TooManyFilesError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.console import api from controllers.console import console_ns
from controllers.console.admin import admin_required from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import ( from controllers.console.wraps import (
@@ -65,6 +65,7 @@ tenants_fields = {
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
@console_ns.route("/workspaces")
class TenantListApi(Resource): class TenantListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -93,6 +94,7 @@ class TenantListApi(Resource):
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200 return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource): class WorkspaceListApi(Resource):
@setup_required @setup_required
@admin_required @admin_required
@@ -118,6 +120,8 @@ class WorkspaceListApi(Resource):
}, 200 }, 200
@console_ns.route("/workspaces/current", endpoint="workspaces_current")
@console_ns.route("/info", endpoint="info") # Deprecated
class TenantApi(Resource): class TenantApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -143,11 +147,10 @@ class TenantApi(Resource):
else: else:
raise Unauthorized("workspace is archived") raise Unauthorized("workspace is archived")
if not tenant:
raise ValueError("No tenant available")
return WorkspaceService.get_tenant_info(tenant), 200 return WorkspaceService.get_tenant_info(tenant), 200
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource): class SwitchWorkspaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -172,6 +175,7 @@ class SwitchWorkspaceApi(Resource):
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource): class CustomConfigWorkspaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -202,6 +206,7 @@ class CustomConfigWorkspaceApi(Resource):
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config/webapp-logo/upload")
class WebappLogoWorkspaceApi(Resource): class WebappLogoWorkspaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -242,6 +247,7 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201 return {"id": upload_file.id}, 201
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource): class WorkspaceInfoApi(Resource):
@setup_required @setup_required
@login_required @login_required
@@ -261,13 +267,3 @@ class WorkspaceInfoApi(Resource):
db.session.commit() db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info

View File

@@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo
raise ValueError("invalid json") raise ValueError("invalid json")
try: try:
payload = payload_type(**data) payload = payload_type.model_validate(data)
except Exception as e: except Exception as e:
raise ValueError(f"invalid payload: {str(e)}") raise ValueError(f"invalid payload: {str(e)}")

View File

@@ -1,10 +1,10 @@
from typing import Literal from typing import Any, Literal, cast
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service import services
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
@@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import Dataset, DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
@@ -31,12 +32,6 @@ def _validate_name(name):
return name return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
# Define parsers for dataset operations # Define parsers for dataset operations
dataset_create_parser = reqparse.RequestParser() dataset_create_parser = reqparse.RequestParser()
dataset_create_parser.add_argument( dataset_create_parser.add_argument(
@@ -48,7 +43,7 @@ dataset_create_parser.add_argument(
) )
dataset_create_parser.add_argument( dataset_create_parser.add_argument(
"description", "description",
type=_validate_description_length, type=validate_description_length,
nullable=True, nullable=True,
required=False, required=False,
default="", default="",
@@ -101,7 +96,7 @@ dataset_update_parser.add_argument(
type=_validate_name, type=_validate_name,
) )
dataset_update_parser.add_argument( dataset_update_parser.add_argument(
"description", location="json", store_missing=False, type=_validate_description_length "description", location="json", store_missing=False, type=validate_description_length
) )
dataset_update_parser.add_argument( dataset_update_parser.add_argument(
"indexing_technique", "indexing_technique",
@@ -254,19 +249,21 @@ class DatasetListApi(DatasetApiResource):
"""Resource for creating datasets.""" """Resource for creating datasets."""
args = dataset_create_parser.parse_args() args = dataset_create_parser.parse_args()
if args.get("embedding_model_provider"): embedding_model_provider = args.get("embedding_model_provider")
DatasetService.check_embedding_model_setting( embedding_model = args.get("embedding_model")
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") if embedding_model_provider and embedding_model:
) DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model") retrieval_model
and args.get("retrieval_model").get("reranking_model") and retrieval_model.get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), retrieval_model.get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), retrieval_model.get("reranking_model").get("reranking_model_name"),
) )
try: try:
@@ -283,7 +280,7 @@ class DatasetListApi(DatasetApiResource):
external_knowledge_id=args["external_knowledge_id"], external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"], embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"], embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"]) retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None if args["retrieval_model"] is not None
else None, else None,
) )
@@ -317,7 +314,7 @@ class DatasetApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
@@ -331,8 +328,8 @@ class DatasetApi(DatasetApiResource):
for embedding_model in embedding_models: for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality": if data.get("indexing_technique") == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
if item_model in model_names: if item_model in model_names:
data["embedding_available"] = True data["embedding_available"] = True
else: else:
@@ -341,7 +338,9 @@ class DatasetApi(DatasetApiResource):
data["embedding_available"] = True data["embedding_available"] = True
# force update search method to keyword_search if indexing_technique is economic # force update search method to keyword_search if indexing_technique is economic
data["retrieval_model_dict"]["search_method"] = "keyword_search" retrieval_model_dict = data.get("retrieval_model_dict")
if retrieval_model_dict:
retrieval_model_dict["search_method"] = "keyword_search"
if data.get("permission") == "partial_members": if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@@ -372,19 +371,24 @@ class DatasetApi(DatasetApiResource):
data = request.get_json() data = request.get_json()
# check embedding model setting # check embedding model setting
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"): embedding_model_provider = data.get("embedding_model_provider")
DatasetService.check_embedding_model_setting( embedding_model = data.get("embedding_model")
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
) if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model
)
retrieval_model = data.get("retrieval_model")
if ( if (
data.get("retrieval_model") retrieval_model
and data.get("retrieval_model").get("reranking_model") and retrieval_model.get("reranking_model")
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name") and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
dataset.tenant_id, dataset.tenant_id,
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), retrieval_model.get("reranking_model").get("reranking_provider_name"),
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"), retrieval_model.get("reranking_model").get("reranking_model_name"),
) )
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@@ -397,7 +401,7 @@ class DatasetApi(DatasetApiResource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields) result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@@ -591,9 +595,10 @@ class DatasetTagsApi(DatasetApiResource):
args = tag_update_parser.parse_args() args = tag_update_parser.parse_args()
args["type"] = "knowledge" args["type"] = "knowledge"
tag = TagService.update_tags(args, args.get("tag_id")) tag_id = args["tag_id"]
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(args.get("tag_id")) binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
@@ -616,7 +621,7 @@ class DatasetTagsApi(DatasetApiResource):
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
args = tag_delete_parser.parse_args() args = tag_delete_parser.parse_args()
TagService.delete_tag(args.get("tag_id")) TagService.delete_tag(args["tag_id"])
return 204 return 204

View File

@@ -108,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource):
if text is None or name is None: if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.") raise ValueError("Both 'text' and 'name' must be non-null values.")
if args.get("embedding_model_provider"): embedding_model_provider = args.get("embedding_model_provider")
DatasetService.check_embedding_model_setting( embedding_model = args.get("embedding_model")
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") if embedding_model_provider and embedding_model:
) DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model") retrieval_model
and args.get("retrieval_model").get("reranking_model") and retrieval_model.get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), retrieval_model.get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), retrieval_model.get("reranking_model").get("reranking_model_name"),
) )
if not current_user: if not current_user:
@@ -134,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
} }
args["data_source"] = data_source args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
# validate args # validate args
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
@@ -187,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource):
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model") retrieval_model
and args.get("retrieval_model").get("reranking_model") and retrieval_model.get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), retrieval_model.get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), retrieval_model.get("reranking_model").get("reranking_model_name"),
) )
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
@@ -218,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
try: try:
@@ -325,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource):
} }
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
@@ -423,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
try: try:

View File

@@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset.""" """Create metadata for a dataset."""
args = metadata_create_parser.parse_args() args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs(**args) metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
return marshal(metadata, dataset_metadata_fields), 200 return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata") @service_api_ns.doc("delete_dataset_metadata")
@@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args() args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData(**args) metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
parser.add_argument("is_published", type=bool, required=True, location="json") parser.add_argument("is_published", type=bool, required=True, location="json")
args: ParseResult = parser.parse_args() args: ParseResult = parser.parse_args()
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args) datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService() rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)

View File

@@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = segment_update_parser.parse_args() args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment( updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]), segment, document, dataset SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
) )
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200

View File

@@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user_id = enterprise_user_decoded.get("end_user_id") end_user_id = enterprise_user_decoded.get("end_user_id")
session_id = enterprise_user_decoded.get("session_id") session_id = enterprise_user_decoded.get("session_id")
user_auth_type = enterprise_user_decoded.get("auth_type") user_auth_type = enterprise_user_decoded.get("auth_type")
exchanged_token_expires_unix = enterprise_user_decoded.get("exp")
if not user_auth_type: if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.") raise Unauthorized("Missing auth_type in the token.")
@@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
) )
db.session.add(end_user) db.session.add(end_user)
db.session.commit() db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp()) exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp())
if exchanged_token_expires_unix:
exp = int(exchanged_token_expires_unix)
payload = { payload = {
"iss": site.id, "iss": site.id,
"sub": "Web API Passport", "sub": "Web API Passport",

View File

@@ -40,7 +40,7 @@ class AgentConfigManager:
"credential_id": tool.get("credential_id", None), "credential_id": tool.get("credential_id", None),
} }
agent_tools.append(AgentToolEntity(**agent_tool_properties)) agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router", "react_router",

View File

@@ -1,4 +1,5 @@
import uuid import uuid
from typing import Literal, cast
from core.app.app_config.entities import ( from core.app.app_config.entities import (
DatasetEntity, DatasetEntity,
@@ -74,6 +75,9 @@ class DatasetConfigManager:
return None return None
query_variable = config.get("dataset_query_variable") query_variable = config.get("dataset_query_variable")
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
if dataset_configs["retrieval_model"] == "single": if dataset_configs["retrieval_model"] == "single":
return DatasetEntity( return DatasetEntity(
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
@@ -82,18 +86,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"] dataset_configs["retrieval_model"]
), ),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), metadata_filtering_mode=cast(
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) Literal["disabled", "automatic", "manual"],
if dataset_configs.get("metadata_model_config") dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None, else None,
metadata_filtering_conditions=MetadataFilteringCondition( metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
**dataset_configs.get("metadata_filtering_conditions", {}) if isinstance(metadata_filtering_conditions_dict, dict)
)
if dataset_configs.get("metadata_filtering_conditions")
else None, else None,
), ),
) )
else: else:
score_threshold_val = dataset_configs.get("score_threshold")
reranking_model_val = dataset_configs.get("reranking_model")
weights_val = dataset_configs.get("weights")
return DatasetEntity( return DatasetEntity(
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity( retrieve_config=DatasetRetrieveConfigEntity(
@@ -101,22 +110,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"] dataset_configs["retrieval_model"]
), ),
top_k=dataset_configs.get("top_k", 4), top_k=int(dataset_configs.get("top_k", 4)),
score_threshold=dataset_configs.get("score_threshold") score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None, else None,
reranking_model=dataset_configs.get("reranking_model"), reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=dataset_configs.get("weights"), weights=weights_val if isinstance(weights_val, dict) else None,
reranking_enabled=dataset_configs.get("reranking_enabled", True), reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), metadata_filtering_mode=cast(
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) Literal["disabled", "automatic", "manual"],
if dataset_configs.get("metadata_model_config") dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None, else None,
metadata_filtering_conditions=MetadataFilteringCondition( metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
**dataset_configs.get("metadata_filtering_conditions", {}) if isinstance(metadata_filtering_conditions_dict, dict)
)
if dataset_configs.get("metadata_filtering_conditions")
else None, else None,
), ),
) )
@@ -134,18 +144,17 @@ class DatasetConfigManager:
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config) config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
# dataset_configs # dataset_configs
if not config.get("dataset_configs"): if "dataset_configs" not in config or not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"} config["dataset_configs"] = {}
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
if not isinstance(config["dataset_configs"], dict): if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type") raise ValueError("dataset_configs must be of object type")
if not config["dataset_configs"].get("datasets"): if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
"datasets", {}
).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION: if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion # Only check when mode is completion
@@ -166,8 +175,8 @@ class DatasetConfigManager:
:param config: app model config args :param config: app model config args
""" """
# Extract dataset config for legacy compatibility # Extract dataset config for legacy compatibility
if not config.get("agent_mode"): if "agent_mode" not in config or not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []} config["agent_mode"] = {}
if not isinstance(config["agent_mode"], dict): if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type") raise ValueError("agent_mode must be of object type")
@@ -180,19 +189,22 @@ class DatasetConfigManager:
raise ValueError("enabled in agent_mode must be of boolean type") raise ValueError("enabled in agent_mode must be of boolean type")
# tools # tools
if not config["agent_mode"].get("tools"): if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = [] config["agent_mode"]["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list): if not isinstance(config["agent_mode"]["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects") raise ValueError("tools in agent_mode must be a list of objects")
# strategy # strategy
if not config["agent_mode"].get("strategy"): if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False has_datasets = False
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: if config.get("agent_mode", {}).get("strategy") in {
for tool in config["agent_mode"]["tools"]: PlanningStrategy.ROUTER.value,
PlanningStrategy.REACT_ROUTER.value,
}:
for tool in config.get("agent_mode", {}).get("tools", []):
key = list(tool.keys())[0] key = list(tool.keys())[0]
if key == "dataset": if key == "dataset":
# old style, use tool name as key # old style, use tool name as key
@@ -217,7 +229,7 @@ class DatasetConfigManager:
has_datasets = True has_datasets = True
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"] need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION: if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion # Only check when mode is completion

View File

@@ -1,9 +1,11 @@
import logging
import queue import queue
import time import time
from abc import abstractmethod from abc import abstractmethod
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Any from typing import Any
from redis.exceptions import RedisError
from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeMeta
from configs import dify_config from configs import dify_config
@@ -18,6 +20,8 @@ from core.app.entities.queue_entities import (
) )
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class PublishFrom(IntEnum): class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto() APPLICATION_MANAGER = auto()
@@ -35,9 +39,8 @@ class AppQueueManager:
self.invoke_from = invoke_from # Public accessor for invoke_from self.invoke_from = invoke_from # Public accessor for invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex( self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id)
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}")
)
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
@@ -79,9 +82,21 @@ class AppQueueManager:
Stop listen to queue Stop listen to queue
:return: :return:
""" """
self._clear_task_belong_cache()
self._q.put(None) self._q.put(None)
def publish_error(self, e, pub_from: PublishFrom): def _clear_task_belong_cache(self) -> None:
"""
Remove the task belong cache key once listening is finished.
"""
try:
redis_client.delete(self._task_belong_cache_key)
except RedisError:
logger.exception(
"Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key
)
def publish_error(self, e, pub_from: PublishFrom) -> None:
""" """
Publish error Publish error
:param e: error :param e: error

View File

@@ -61,9 +61,6 @@ class AppRunner:
if model_context_tokens is None: if model_context_tokens is None:
return -1 return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens: if prompt_tokens + max_tokens > model_context_tokens:

View File

@@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
rag_pipeline_variables = [] rag_pipeline_variables = []
if workflow.rag_pipeline_variables: if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v) rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
if ( if (
rag_pipeline_variable.belong_to_node_id rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared") in (self.application_generate_entity.start_node_id, "shared")

View File

@@ -107,7 +107,6 @@ class MessageCycleManager:
if dify_config.DEBUG: if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id) logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
db.session.merge(conversation)
db.session.commit() db.session.commit()
db.session.close() db.session.close()

View File

@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from openai import BaseModel from pydantic import BaseModel, Field
from pydantic import Field
# Import InvokeFrom locally to avoid circular import # Import InvokeFrom locally to avoid circular import
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel): class I18nObject(BaseModel):
@@ -11,11 +11,12 @@ class I18nObject(BaseModel):
pt_BR: str | None = Field(default=None) pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None) ja_JP: str | None = Field(default=None)
def __init__(self, **data): @model_validator(mode="after")
super().__init__(**data) def _(self):
self.zh_Hans = self.zh_Hans or self.en_US self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict: def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence from collections.abc import Iterator, Sequence
from json import JSONDecodeError from json import JSONDecodeError
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel):
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data): @model_validator(mode="after")
super().__init__(**data) def _(self):
if self.provider.provider not in original_provider_configurate_methods: if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = [] original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods: for configurate_method in self.provider.configurate_methods:
@@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel):
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
): ):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
return self
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
""" """

View File

@@ -1,13 +1,13 @@
from typing import cast from typing import cast
import requests import httpx
from configs import dify_config from configs import dify_config
from models.api_based_extension import APIBasedExtensionPoint from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor: class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60) timeout: httpx.Timeout = httpx.Timeout(60.0, connect=5.0)
"""timeout for request connect and read""" """timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str): def __init__(self, api_endpoint: str, api_key: str):
@@ -27,25 +27,23 @@ class APIBasedExtensionRequestor:
url = self.api_endpoint url = self.api_endpoint
try: try:
# proxy support for security mounts: dict[str, httpx.BaseTransport] | None = None
proxies = None
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = { mounts = {
"http": dify_config.SSRF_PROXY_HTTP_URL, "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https": dify_config.SSRF_PROXY_HTTPS_URL, "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
} }
response = requests.request( with httpx.Client(mounts=mounts, timeout=self.timeout) as client:
method="POST", response = client.request(
url=url, method="POST",
json={"point": point.value, "params": params}, url=url,
headers=headers, json={"point": point.value, "params": params},
timeout=self.timeout, headers=headers,
proxies=proxies, )
) except httpx.TimeoutException:
except requests.Timeout:
raise ValueError("request timeout") raise ValueError("request timeout")
except requests.ConnectionError: except httpx.RequestError:
raise ValueError("request connection error") raise ValueError("request connection error")
if response.status_code != 200: if response.status_code != 200:

View File

@@ -131,7 +131,7 @@ class CodeExecutor:
if (code := response_data.get("code")) != 0: if (code := response_data.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}") raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
response_code = CodeExecutionResponse(**response_data) response_code = CodeExecutionResponse.model_validate(response_data)
if response_code.data.error: if response_code.data.error:
raise CodeExecutionError(response_code.data.error) raise CodeExecutionError(response_code.data.error)

View File

@@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version}) response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status() response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_manifests_ignore_deserialization_error( def batch_fetch_plugin_manifests_ignore_deserialization_error(
@@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
result: list[MarketplacePluginDeclaration] = [] result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]: for plugin in response.json()["data"]["plugins"]:
try: try:
result.append(MarketplacePluginDeclaration(**plugin)) result.append(MarketplacePluginDeclaration.model_validate(plugin))
except Exception: except Exception:
pass pass

View File

@@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -357,14 +357,16 @@ class IndexingRunner:
raise ValueError("no notion import info found") raise ValueError("no notion import info found")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": data_source_info["credential_id"], {
"notion_workspace_id": data_source_info["notion_workspace_id"], "credential_id": data_source_info["credential_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_page_type": data_source_info["type"], "notion_obj_id": data_source_info["notion_page_id"],
"document": dataset_document, "notion_page_type": data_source_info["type"],
"tenant_id": dataset_document.tenant_id, "document": dataset_document,
}, "tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form, document_model=dataset_document.doc_form,
) )
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
@@ -378,14 +380,16 @@ class IndexingRunner:
raise ValueError("no website import info found") raise ValueError("no website import info found")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info=WebsiteInfo.model_validate(
"provider": data_source_info["provider"], {
"job_id": data_source_info["job_id"], "provider": data_source_info["provider"],
"tenant_id": dataset_document.tenant_id, "job_id": data_source_info["job_id"],
"url": data_source_info["url"], "tenant_id": dataset_document.tenant_id,
"mode": data_source_info["mode"], "url": data_source_info["url"],
"only_main_content": data_source_info["only_main_content"], "mode": data_source_info["mode"],
}, "only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form, document_model=dataset_document.doc_form,
) )
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

View File

@@ -294,7 +294,7 @@ class ClientSession(
method="completion/complete", method="completion/complete",
params=types.CompleteRequestParams( params=types.CompleteRequestParams(
ref=ref, ref=ref,
argument=types.CompletionArgument(**argument), argument=types.CompletionArgument.model_validate(argument),
), ),
) )
), ),

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel from pydantic import BaseModel, model_validator
class I18nObject(BaseModel): class I18nObject(BaseModel):
@@ -9,7 +9,8 @@ class I18nObject(BaseModel):
zh_Hans: str | None = None zh_Hans: str | None = None
en_US: str en_US: str
def __init__(self, **data): @model_validator(mode="after")
super().__init__(**data) def _(self):
if not self.zh_Hans: if not self.zh_Hans:
self.zh_Hans = self.en_US self.zh_Hans = self.en_US
return self

View File

@@ -74,7 +74,7 @@ class TextPromptMessageContent(PromptMessageContent):
Model class for text prompt message content. Model class for text prompt message content.
""" """
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
data: str data: str
@@ -95,11 +95,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
class VideoPromptMessageContent(MultiModalPromptMessageContent): class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
class AudioPromptMessageContent(MultiModalPromptMessageContent): class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
class ImagePromptMessageContent(MultiModalPromptMessageContent): class ImagePromptMessageContent(MultiModalPromptMessageContent):
@@ -111,12 +111,12 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
LOW = auto() LOW = auto()
HIGH = auto() HIGH = auto()
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent): class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
PromptMessageContentUnionTypes = Annotated[ PromptMessageContentUnionTypes = Annotated[

View File

@@ -1,7 +1,7 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum, StrEnum, auto from enum import Enum, StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -46,10 +46,11 @@ class FormOption(BaseModel):
value: str value: str
show_on: list[FormShowOnObject] = [] show_on: list[FormShowOnObject] = []
def __init__(self, **data): @model_validator(mode="after")
super().__init__(**data) def _(self):
if not self.label: if not self.label:
self.label = I18nObject(en_US=self.value) self.label = I18nObject(en_US=self.value)
return self
class CredentialFormSchema(BaseModel): class CredentialFormSchema(BaseModel):

View File

@@ -15,7 +15,7 @@ class GPT2Tokenizer:
use gpt2 tokenizer to get num tokens use gpt2 tokenizer to get num tokens
""" """
_tokenizer = GPT2Tokenizer.get_encoder() _tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text) tokens = _tokenizer.encode(text) # type: ignore
return len(tokens) return len(tokens)
@staticmethod @staticmethod

View File

@@ -269,17 +269,17 @@ class ModelProviderFactory:
} }
if model_type == ModelType.LLM: if model_type == ModelType.LLM:
return LargeLanguageModel(**init_params) # type: ignore return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING: elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(**init_params) # type: ignore return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK: elif model_type == ModelType.RERANK:
return RerankModel(**init_params) # type: ignore return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT: elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(**init_params) # type: ignore return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION: elif model_type == ModelType.MODERATION:
return ModerationModel(**init_params) # type: ignore return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS: elif model_type == ModelType.TTS:
return TTSModel(**init_params) # type: ignore return TTSModel.model_validate(init_params)
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
""" """

View File

@@ -196,15 +196,15 @@ def jsonable_encoder(
return encoder(obj) return encoder(obj)
try: try:
data = dict(obj) data = dict(obj) # type: ignore
except Exception as e: except Exception as e:
errors: list[Exception] = [] errors: list[Exception] = []
errors.append(e) errors.append(e)
try: try:
data = vars(obj) data = vars(obj) # type: ignore
except Exception as e: except Exception as e:
errors.append(e) errors.append(e)
raise ValueError(errors) from e raise ValueError(str(errors)) from e
return jsonable_encoder( return jsonable_encoder(
data, data,
by_alias=by_alias, by_alias=by_alias,

View File

@@ -51,7 +51,7 @@ class ApiModeration(Moderation):
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
return ModerationInputsResult(**result) return ModerationInputsResult.model_validate(result)
return ModerationInputsResult( return ModerationInputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
@@ -67,7 +67,7 @@ class ApiModeration(Moderation):
params = ModerationOutputParams(app_id=self.app_id, text=text) params = ModerationOutputParams(app_id=self.app_id, text=text)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
return ModerationOutputsResult(**result) return ModerationOutputsResult.model_validate(result)
return ModerationOutputsResult( return ModerationOutputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response

View File

@@ -3,7 +3,8 @@ from dataclasses import dataclass
from typing import Any from typing import Any
from opentelemetry import trace as trace_api from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event, Status, StatusCode from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@@ -155,7 +155,10 @@ class OpsTraceManager:
if key in tracing_config: if key in tracing_config:
if "*" in tracing_config[key]: if "*" in tracing_config[key]:
# If the key contains '*', retain the original value from the current config # If the key contains '*', retain the original value from the current config
new_config[key] = current_trace_config.get(key, tracing_config[key]) if current_trace_config:
new_config[key] = current_trace_config.get(key, tracing_config[key])
else:
new_config[key] = tracing_config[key]
else: else:
# Otherwise, encrypt the key # Otherwise, encrypt the key
new_config[key] = encrypt_token(tenant_id, tracing_config[key]) new_config[key] = encrypt_token(tenant_id, tracing_config[key])

View File

@@ -62,7 +62,8 @@ class WeaveDataTrace(BaseTraceInstance):
self, self,
): ):
try: try:
project_url = f"https://wandb.ai/{self.weave_client._project_id()}" project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
project_url = f"https://wandb.ai/{project_identifier}"
return project_url return project_url
except Exception as e: except Exception as e:
logger.debug("Weave get run url failed: %s", str(e)) logger.debug("Weave get run url failed: %s", str(e))
@@ -424,7 +425,23 @@ class WeaveDataTrace(BaseTraceInstance):
raise ValueError(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}")
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None): def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) inputs = run_data.inputs
if inputs is None:
inputs = {}
elif not isinstance(inputs, dict):
inputs = {"inputs": str(inputs)}
attributes = run_data.attributes
if attributes is None:
attributes = {}
elif not isinstance(attributes, dict):
attributes = {"attributes": str(attributes)}
call = self.weave_client.create_call(
op=run_data.op,
inputs=inputs,
attributes=attributes,
)
self.calls[run_data.id] = call self.calls[run_data.id] = call
if parent_run_id: if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id self.calls[run_data.id].parent_id = parent_run_id
@@ -432,6 +449,7 @@ class WeaveDataTrace(BaseTraceInstance):
def finish_call(self, run_data: WeaveTraceModel): def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id) call = self.calls.get(run_data.id)
if call: if call:
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception) exception = Exception(run_data.exception) if run_data.exception else None
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
else: else:
raise ValueError(f"Call with id {run_data.id} not found") raise ValueError(f"Call with id {run_data.id} not found")

View File

@@ -84,15 +84,15 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
for i in range(len(v)): for i in range(len(v)):
if v[i]["role"] == PromptMessageRole.USER.value: if v[i]["role"] == PromptMessageRole.USER.value:
v[i] = UserPromptMessage(**v[i]) v[i] = UserPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value: elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
v[i] = AssistantPromptMessage(**v[i]) v[i] = AssistantPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.SYSTEM.value: elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
v[i] = SystemPromptMessage(**v[i]) v[i] = SystemPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.TOOL.value: elif v[i]["role"] == PromptMessageRole.TOOL.value:
v[i] = ToolPromptMessage(**v[i]) v[i] = ToolPromptMessage.model_validate(v[i])
else: else:
v[i] = PromptMessage(**v[i]) v[i] = PromptMessage.model_validate(v[i])
return v return v

View File

@@ -2,11 +2,10 @@ import inspect
import json import json
import logging import logging
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from typing import TypeVar from typing import Any, TypeVar
import requests import httpx
from pydantic import BaseModel from pydantic import BaseModel
from requests.exceptions import HTTPError
from yarl import URL from yarl import URL
from configs import dify_config from configs import dify_config
@@ -47,29 +46,56 @@ class BasePluginClient:
data: bytes | dict | str | None = None, data: bytes | dict | str | None = None,
params: dict | None = None, params: dict | None = None,
files: dict | None = None, files: dict | None = None,
stream: bool = False, ) -> httpx.Response:
) -> requests.Response:
""" """
Make a request to the plugin daemon inner API. Make a request to the plugin daemon inner API.
""" """
url = plugin_daemon_inner_api_baseurl / path url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
headers = headers or {}
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
headers["Accept-Encoding"] = "gzip, deflate, br"
if headers.get("Content-Type") == "application/json" and isinstance(data, dict): request_kwargs: dict[str, Any] = {
data = json.dumps(data) "method": method,
"url": url,
"headers": headers,
"params": params,
"files": files,
}
if isinstance(prepared_data, dict):
request_kwargs["data"] = prepared_data
elif prepared_data is not None:
request_kwargs["content"] = prepared_data
try: try:
response = requests.request( response = httpx.request(**request_kwargs)
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files except httpx.RequestError:
)
except requests.ConnectionError:
logger.exception("Request to Plugin Daemon Service failed") logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
return response return response
def _prepare_request(
self,
path: str,
headers: dict | None,
data: bytes | dict | str | None,
params: dict | None,
files: dict | None,
) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]:
url = plugin_daemon_inner_api_baseurl / path
prepared_headers = dict(headers or {})
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
prepared_data: bytes | dict | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
if isinstance(data, dict):
if prepared_headers.get("Content-Type") == "application/json":
prepared_data = json.dumps(data)
else:
prepared_data = data
return str(url), prepared_headers, prepared_data, params, files
def _stream_request( def _stream_request(
self, self,
method: str, method: str,
@@ -78,23 +104,44 @@ class BasePluginClient:
headers: dict | None = None, headers: dict | None = None,
data: bytes | dict | None = None, data: bytes | dict | None = None,
files: dict | None = None, files: dict | None = None,
) -> Generator[bytes, None, None]: ) -> Generator[str, None, None]:
""" """
Make a stream request to the plugin daemon inner API Make a stream request to the plugin daemon inner API
""" """
response = self._request(method, path, headers, data, params, files, stream=True) url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
for line in response.iter_lines(chunk_size=1024 * 8):
line = line.decode("utf-8").strip() stream_kwargs: dict[str, Any] = {
if line.startswith("data:"): "method": method,
line = line[5:].strip() "url": url,
if line: "headers": headers,
yield line "params": params,
"files": files,
}
if isinstance(prepared_data, dict):
stream_kwargs["data"] = prepared_data
elif prepared_data is not None:
stream_kwargs["content"] = prepared_data
try:
with httpx.stream(**stream_kwargs) as response:
for raw_line in response.iter_lines():
if raw_line is None:
continue
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
line = line.strip()
if line.startswith("data:"):
line = line[5:].strip()
if line:
yield line
except httpx.RequestError:
logger.exception("Stream request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
def _stream_request_with_model( def _stream_request_with_model(
self, self,
method: str, method: str,
path: str, path: str,
type: type[T], type_: type[T],
headers: dict | None = None, headers: dict | None = None,
data: bytes | dict | None = None, data: bytes | dict | None = None,
params: dict | None = None, params: dict | None = None,
@@ -104,13 +151,13 @@ class BasePluginClient:
Make a stream request to the plugin daemon inner API and yield the response as a model. Make a stream request to the plugin daemon inner API and yield the response as a model.
""" """
for line in self._stream_request(method, path, params, headers, data, files): for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line)) # type: ignore yield type_(**json.loads(line)) # type: ignore
def _request_with_model( def _request_with_model(
self, self,
method: str, method: str,
path: str, path: str,
type: type[T], type_: type[T],
headers: dict | None = None, headers: dict | None = None,
data: bytes | None = None, data: bytes | None = None,
params: dict | None = None, params: dict | None = None,
@@ -120,13 +167,13 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model. Make a request to the plugin daemon inner API and return the response as a model.
""" """
response = self._request(method, path, headers, data, params, files) response = self._request(method, path, headers, data, params, files)
return type(**response.json()) # type: ignore return type_(**response.json()) # type: ignore
def _request_with_plugin_daemon_response( def _request_with_plugin_daemon_response(
self, self,
method: str, method: str,
path: str, path: str,
type: type[T], type_: type[T],
headers: dict | None = None, headers: dict | None = None,
data: bytes | dict | None = None, data: bytes | dict | None = None,
params: dict | None = None, params: dict | None = None,
@@ -139,23 +186,23 @@ class BasePluginClient:
try: try:
response = self._request(method, path, headers, data, params, files) response = self._request(method, path, headers, data, params, files)
response.raise_for_status() response.raise_for_status()
except HTTPError as e: except httpx.HTTPStatusError as e:
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}" logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
logger.exception(msg)
raise e raise e
except Exception as e: except Exception as e:
msg = f"Failed to request plugin daemon, url: {path}" msg = f"Failed to request plugin daemon, url: {path}"
logger.exception(msg) logger.exception("Failed to request plugin daemon, url: %s", path)
raise ValueError(msg) from e raise ValueError(msg) from e
try: try:
json_response = response.json() json_response = response.json()
if transformer: if transformer:
json_response = transformer(json_response) json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore # https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore
except Exception: except Exception:
msg = ( msg = (
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}]," f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
f" url: {path}" f" url: {path}"
) )
logger.exception(msg) logger.exception(msg)
@@ -163,7 +210,7 @@ class BasePluginClient:
if rep.code != 0: if rep.code != 0:
try: try:
error = PluginDaemonError(**json.loads(rep.message)) error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception: except Exception:
raise ValueError(f"{rep.message}, code: {rep.code}") raise ValueError(f"{rep.message}, code: {rep.code}")
@@ -178,7 +225,7 @@ class BasePluginClient:
self, self,
method: str, method: str,
path: str, path: str,
type: type[T], type_: type[T],
headers: dict | None = None, headers: dict | None = None,
data: bytes | dict | None = None, data: bytes | dict | None = None,
params: dict | None = None, params: dict | None = None,
@@ -189,7 +236,7 @@ class BasePluginClient:
""" """
for line in self._stream_request(method, path, params, headers, data, files): for line in self._stream_request(method, path, params, headers, data, files):
try: try:
rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore
except (ValueError, TypeError): except (ValueError, TypeError):
# TODO modify this when line_data has code and message # TODO modify this when line_data has code and message
try: try:
@@ -204,7 +251,7 @@ class BasePluginClient:
if rep.code != 0: if rep.code != 0:
if rep.code == -500: if rep.code == -500:
try: try:
error = PluginDaemonError(**json.loads(rep.message)) error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception: except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message) raise PluginDaemonInnerError(code=rep.code, message=rep.message)

View File

@@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient):
params={"page": 1, "page_size": 256}, params={"page": 1, "page_size": 256},
transformer=transformer, transformer=transformer,
) )
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
self._get_local_file_datasource_provider()
)
for provider in response: for provider in response:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
@@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource provider for the given tenant and plugin. Fetch datasource provider for the given tenant and plugin.
""" """
if provider_id == "langgenius/file/file": if provider_id == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())
tool_provider_id = DatasourceProviderID(provider_id) tool_provider_id = DatasourceProviderID(provider_id)

View File

@@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/invoke", path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type=LLMResultChunk, type_=LLMResultChunk,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type=PluginLLMNumTokensResponse, type_=PluginLLMNumTokensResponse,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type=TextEmbeddingResult, type_=TextEmbeddingResult,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type=PluginTextEmbeddingNumTokensResponse, type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/rerank/invoke", path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type=RerankResult, type_=RerankResult,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/invoke", path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type=PluginStringResultResponse, type_=PluginStringResultResponse,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/model/voices", path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type=PluginVoicesResponse, type_=PluginVoicesResponse,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type=PluginStringResultResponse, type_=PluginStringResultResponse,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,
@@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
method="POST", method="POST",
path=f"plugin/{tenant_id}/dispatch/moderation/invoke", path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type=PluginBasicBooleanResponse, type_=PluginBasicBooleanResponse,
data=jsonable_encoder( data=jsonable_encoder(
{ {
"user_id": user_id, "user_id": user_id,

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator from collections.abc import Generator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TypeVar, Union, cast from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
@@ -87,7 +87,8 @@ def merge_blob_chunks(
), ),
meta=resp.meta, meta=resp.meta,
) )
yield cast(MessageType, merged_message) assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
yield merged_message # type: ignore
# Clean up the buffer # Clean up the buffer
del files[chunk_id] del files[chunk_id]
else: else:

View File

@@ -106,7 +106,9 @@ class RetrievalService:
if exceptions: if exceptions:
raise ValueError(";\n".join(exceptions)) raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor( data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
) )
@@ -132,7 +134,7 @@ class RetrievalService:
if not dataset: if not dataset:
return [] return []
metadata_condition = ( metadata_condition = (
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
) )
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id, dataset.tenant_id,
@@ -143,6 +145,40 @@ class RetrievalService:
) )
return all_documents return all_documents
@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
if not documents:
return documents
unique_documents = []
seen_doc_ids = set()
for document in documents:
# For dify provider documents, use doc_id for deduplication
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
doc_id = document.metadata["doc_id"]
if doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
unique_documents.append(document)
# If duplicate, keep the one with higher score
elif "score" in document.metadata:
# Find existing document with same doc_id and compare scores
for i, existing_doc in enumerate(unique_documents):
if (
existing_doc.metadata
and existing_doc.metadata.get("doc_id") == doc_id
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
):
unique_documents[i] = document
break
else:
# For non-dify documents, use content-based deduplication
if document not in unique_documents:
unique_documents.append(document)
return unique_documents
@classmethod @classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None: def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session: with Session(db.engine) as session:

View File

@@ -4,7 +4,7 @@ import math
from typing import Any, cast from typing import Any, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import requests from elasticsearch import ConnectionError as ElasticsearchConnectionError
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from flask import current_app from flask import current_app
from packaging.version import parse as parse_version from packaging.version import parse as parse_version
@@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
if not client.ping(): if not client.ping():
raise ConnectionError("Failed to connect to Elasticsearch") raise ConnectionError("Failed to connect to Elasticsearch")
except requests.ConnectionError as e: except ElasticsearchConnectionError as e:
raise ConnectionError(f"Vector database connection error: {str(e)}") raise ConnectionError(f"Vector database connection error: {str(e)}")
except Exception as e: except Exception as e:
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")

View File

@@ -5,9 +5,10 @@ from collections.abc import Generator, Iterable, Sequence
from itertools import islice from itertools import islice
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any, Union
import httpx
import qdrant_client import qdrant_client
import requests
from flask import current_app from flask import current_app
from httpx import DigestAuth
from pydantic import BaseModel from pydantic import BaseModel
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
from qdrant_client.http.models import ( from qdrant_client.http.models import (
@@ -19,7 +20,6 @@ from qdrant_client.http.models import (
TokenizerType, TokenizerType,
) )
from qdrant_client.local.qdrant_local import QdrantLocal from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select from sqlalchemy import select
from configs import dify_config from configs import dify_config
@@ -504,10 +504,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
} }
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels} cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
response = requests.post( response = httpx.post(
f"{tidb_config.api_url}/clusters", f"{tidb_config.api_url}/clusters",
json=cluster_data, json=cluster_data,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
) )
if response.status_code == 200: if response.status_code == 200:
@@ -527,10 +527,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
body = {"password": new_password} body = {"password": new_password}
response = requests.put( response = httpx.put(
f"{tidb_config.api_url}/clusters/{cluster_id}/password", f"{tidb_config.api_url}/clusters/{cluster_id}/password",
json=body, json=body,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
) )
if response.status_code == 200: if response.status_code == 200:

View File

@@ -2,8 +2,8 @@ import time
import uuid import uuid
from collections.abc import Sequence from collections.abc import Sequence
import requests import httpx
from requests.auth import HTTPDigestAuth from httpx import DigestAuth
from configs import dify_config from configs import dify_config
from extensions.ext_database import db from extensions.ext_database import db
@@ -49,7 +49,7 @@ class TidbService:
"rootPassword": password, "rootPassword": password,
} }
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key)) response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
@@ -83,7 +83,7 @@ class TidbService:
:return: The response from the API. :return: The response from the API.
""" """
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200: if response.status_code == 200:
return response.json() return response.json()
@@ -102,7 +102,7 @@ class TidbService:
:return: The response from the API. :return: The response from the API.
""" """
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200: if response.status_code == 200:
return response.json() return response.json()
@@ -127,10 +127,10 @@ class TidbService:
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = requests.patch( response = httpx.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body, json=body,
auth=HTTPDigestAuth(public_key, private_key), auth=DigestAuth(public_key, private_key),
) )
if response.status_code == 200: if response.status_code == 200:
@@ -161,9 +161,7 @@ class TidbService:
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list] cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"} params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = requests.get( response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key))
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
@@ -224,8 +222,8 @@ class TidbService:
clusters.append(cluster_data) clusters.append(cluster_data)
request_body = {"requests": clusters} request_body = {"requests": clusters}
response = requests.post( response = httpx.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key) f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
) )
if response.status_code == 200: if response.status_code == 200:

View File

@@ -2,7 +2,6 @@ import datetime
import json import json
from typing import Any from typing import Any
import requests
import weaviate # type: ignore import weaviate # type: ignore
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@@ -45,8 +44,8 @@ class WeaviateVector(BaseVector):
client = weaviate.Client( client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
) )
except requests.ConnectionError: except Exception as exc:
raise ConnectionError("Vector database connection error") raise ConnectionError("Vector database connection error") from exc
client.batch.configure( client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching # `batch_size` takes an `int` value to enable auto-batching

View File

@@ -42,6 +42,10 @@ class CacheEmbedding(Embeddings):
text_embeddings[i] = embedding.get_embedding() text_embeddings[i] = embedding.get_embedding()
else: else:
embedding_queue_indices.append(i) embedding_queue_indices.append(i)
# release database connection, because embedding may take a long time
db.session.close()
if embedding_queue_indices: if embedding_queue_indices:
embedding_queue_texts = [texts[i] for i in embedding_queue_indices] embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
embedding_queue_embeddings = [] embedding_queue_embeddings = []

View File

@@ -17,9 +17,6 @@ class NotionInfo(BaseModel):
tenant_id: str tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)
class WebsiteInfo(BaseModel): class WebsiteInfo(BaseModel):
""" """
@@ -47,6 +44,3 @@ class ExtractSetting(BaseModel):
website_info: WebsiteInfo | None = None website_info: WebsiteInfo | None = None
document_model: str | None = None document_model: str | None = None
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)

View File

@@ -2,7 +2,7 @@ import json
import time import time
from typing import Any, cast from typing import Any, cast
import requests import httpx
from extensions.ext_storage import storage from extensions.ext_storage import storage
@@ -104,18 +104,18 @@ class FirecrawlApp:
def _prepare_headers(self) -> dict[str, Any]: def _prepare_headers(self) -> dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response: def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries): for attempt in range(retries):
response = requests.post(url, headers=headers, json=data) response = httpx.post(url, headers=headers, json=data)
if response.status_code == 502: if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt)) time.sleep(backoff_factor * (2**attempt))
else: else:
return response return response
return response return response
def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> requests.Response: def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries): for attempt in range(retries):
response = requests.get(url, headers=headers) response = httpx.get(url, headers=headers)
if response.status_code == 502: if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt)) time.sleep(backoff_factor * (2**attempt))
else: else:

Some files were not shown because too many files have changed in this diff Show More