diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py
new file mode 100644
index 0000000000..e0896a8dc2
--- /dev/null
+++ b/api/controllers/common/schema.py
@@ -0,0 +1,26 @@
+"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
+
+from flask_restx import Namespace
+from pydantic import BaseModel
+
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
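+# flask-restx emits Swagger 2.0 documents, whose schema refs live under
+# "#/definitions/"; Pydantic's default ref_template is "#/$defs/{model}",
+# which Swagger UI cannot resolve, hence this override.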
for the copied app") - description: str | None = Field(default=None, description="Description for the copied app") + description: str | None = Field(default=None, description="Description for the copied app", max_length=400) icon_type: str | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") - @field_validator("description") - @classmethod - def validate_description(cls, value: str | None) -> str | None: - if value is None: - return value - return validate_description_length(value) - class AppExportQuery(BaseModel): include_secret: bool = Field(default=False, description="Include secrets in export") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index ef66053075..01f268d94d 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,15 +1,15 @@ import json from collections.abc import Generator -from typing import cast +from typing import Any, cast from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.common.schema import register_schema_model from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.indexing_runner import IndexingRunner @@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService from services.datasource_provider_service import DatasourceProviderService from tasks.document_indexing_sync_task import document_indexing_sync_task +from .. 
+from .. import console_ns
+from ..wraps import account_initialization_required, setup_required
+
+
+class NotionEstimatePayload(BaseModel):
+    notion_info_list: list[dict[str, Any]]
+    process_rule: dict[str, Any]
+    doc_form: str = Field(default="text_model")
+    doc_language: str = Field(default="English")
+
+
+register_schema_model(console_ns, NotionEstimatePayload)
+
 
 @console_ns.route(
     "/data-source/integrates",
@@ -243,20 +256,15 @@ class DataSourceNotionApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
     def post(self):
         _, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
-            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump()
         # validate args
         DocumentService.estimate_args_validate(args)
-        notion_info_list = args["notion_info_list"]
+        notion_info_list = payload.notion_info_list
         extract_settings = []
         for notion_info in notion_info_list:
             workspace_id = notion_info["workspace_id"]
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 45bc1fa694..1fad8abd52 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,12 +1,14 @@
 from typing import Any, cast
 
 from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
 from configs import dify_config
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.apikey import (
     api_key_item_model,
@@ -48,7 +50,6 @@ from fields.dataset_fields import (
 )
 from fields.document_fields import document_status_fields
 from libs.login import current_account_with_tenant, login_required
-from libs.validators import validate_description_length
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermissionEnum
 from models.provider_ids import ModelProviderID
@@ -107,10 +108,74 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
 related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
 
 
-def _validate_name(name: str) -> str:
-    if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError("Name must be between 1 to 40 characters.")
-    return name
+def _validate_indexing_technique(value: str | None) -> str | None:
+    if value is None:
+        return value
+    if value not in Dataset.INDEXING_TECHNIQUE_LIST:
+        raise ValueError("Invalid indexing technique.")
+    return value
+
+
+class DatasetCreatePayload(BaseModel):
+    name: str = Field(..., min_length=1, max_length=40)
+    description: str = Field("", max_length=400)
+    indexing_technique: str | None = None
+    permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
+    provider: str = "vendor"
+    external_knowledge_api_id: str | None = None
+    external_knowledge_id: str | None = None
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str | None) -> str | None:
+        return _validate_indexing_technique(value)
+
+    @field_validator("provider")
+    @classmethod
+    def validate_provider(cls, value: str) -> str:
+        if value not in Dataset.PROVIDER_LIST:
+            raise ValueError("Invalid provider.")
+        return value
+
+
+class DatasetUpdatePayload(BaseModel):
+    name: str | None = Field(None, min_length=1, max_length=40)
+    description: str | None = Field(None, max_length=400)
+    permission: DatasetPermissionEnum | None = None
+    indexing_technique: str | None = None
+    embedding_model: str | None = None
+    embedding_model_provider: str | None = None
+    retrieval_model: dict[str, Any] | None = None
+    partial_member_list: list[str] | None = None
+    external_retrieval_model: dict[str, Any] | None = None
+    external_knowledge_id: str | None = None
+    external_knowledge_api_id: str | None = None
+    icon_info: dict[str, Any] | None = None
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str | None) -> str | None:
+        return _validate_indexing_technique(value)
+
+
+class IndexingEstimatePayload(BaseModel):
+    info_list: dict[str, Any]
+    process_rule: dict[str, Any]
+    indexing_technique: str
+    doc_form: str = "text_model"
+    dataset_id: str | None = None
+    doc_language: str = "English"
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str) -> str:
+        result = _validate_indexing_technique(value)
+        if result is None:
+            raise ValueError("indexing_technique is required.")
+        return result
+
+
+register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
 
 
 def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@@ -255,20 +320,7 @@ class DatasetListApi(Resource):
 
     @console_ns.doc("create_dataset")
     @console_ns.doc(description="Create a new dataset")
-    @console_ns.expect(
-        console_ns.model(
-            "CreateDatasetRequest",
-            {
-                "name": fields.String(required=True, description="Dataset name (1-40 characters)"),
-                "description": fields.String(description="Dataset description (max 400 characters)"),
-                "indexing_technique": fields.String(description="Indexing technique"),
-                "permission": fields.String(description="Dataset permission"),
-                "provider": fields.String(description="Provider"),
-                "external_knowledge_api_id": fields.String(description="External knowledge API ID"),
-                "external_knowledge_id": fields.String(description="External knowledge ID"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
     @console_ns.response(201, "Dataset created successfully")
     @console_ns.response(400, "Invalid request parameters")
     @setup_required
@@ -276,52 +328,7 @@ class DatasetListApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                required=True,
-                help="type is required. Name must be between 1 to 40 characters.",
-                type=_validate_name,
-            )
-            .add_argument(
-                "description",
-                type=validate_description_length,
-                nullable=True,
-                required=False,
-                default="",
-            )
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                location="json",
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                help="Invalid indexing technique.",
-            )
-            .add_argument(
-                "external_knowledge_api_id",
-                type=str,
-                nullable=True,
-                required=False,
-            )
-            .add_argument(
-                "provider",
-                type=str,
-                nullable=True,
-                choices=Dataset.PROVIDER_LIST,
-                required=False,
-                default="vendor",
-            )
-            .add_argument(
-                "external_knowledge_id",
-                type=str,
-                nullable=True,
-                required=False,
-            )
-        )
-        args = parser.parse_args()
+        payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
 
         current_user, current_tenant_id = current_account_with_tenant()
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -331,14 +338,14 @@ class DatasetListApi(Resource):
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_tenant_id,
-                name=args["name"],
-                description=args["description"],
-                indexing_technique=args["indexing_technique"],
+                name=payload.name,
+                description=payload.description,
+                indexing_technique=payload.indexing_technique,
                 account=current_user,
-                permission=DatasetPermissionEnum.ONLY_ME,
-                provider=args["provider"],
-                external_knowledge_api_id=args["external_knowledge_api_id"],
-                external_knowledge_id=args["external_knowledge_id"],
+                permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
+                provider=payload.provider,
+                external_knowledge_api_id=payload.external_knowledge_api_id,
+                external_knowledge_id=payload.external_knowledge_id,
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -399,18 +406,7 @@ class DatasetApi(Resource):
 
     @console_ns.doc("update_dataset")
     @console_ns.doc(description="Update dataset details")
-    @console_ns.expect(
-        console_ns.model(
-            "UpdateDatasetRequest",
-            {
-                "name": fields.String(description="Dataset name"),
-                "description": fields.String(description="Dataset description"),
-                "permission": fields.String(description="Dataset permission"),
-                "indexing_technique": fields.String(description="Indexing technique"),
-                "external_retrieval_model": fields.Raw(description="External retrieval model settings"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
     @console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
     @console_ns.response(404, "Dataset not found")
     @console_ns.response(403, "Permission denied")
@@ -424,93 +420,26 @@ class DatasetApi(Resource):
         if dataset is None:
             raise NotFound("Dataset not found.")
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                help="type is required. Name must be between 1 to 40 characters.",
-                type=_validate_name,
-            )
-            .add_argument("description", location="json", store_missing=False, type=validate_description_length)
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                location="json",
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                help="Invalid indexing technique.",
-            )
-            .add_argument(
-                "permission",
-                type=str,
-                location="json",
-                choices=(
-                    DatasetPermissionEnum.ONLY_ME,
-                    DatasetPermissionEnum.ALL_TEAM,
-                    DatasetPermissionEnum.PARTIAL_TEAM,
-                ),
-                help="Invalid permission.",
-            )
-            .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
-            .add_argument(
-                "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
-            )
-            .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
-            .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
-            .add_argument(
-                "external_retrieval_model",
-                type=dict,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external retrieval model.",
-            )
-            .add_argument(
-                "external_knowledge_id",
-                type=str,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external knowledge id.",
-            )
-            .add_argument(
-                "external_knowledge_api_id",
-                type=str,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external knowledge api id.",
-            )
-            .add_argument(
-                "icon_info",
-                type=dict,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid icon info.",
-            )
-        )
-        args = parser.parse_args()
-        data = request.get_json()
+        payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
+        payload_data = payload.model_dump(exclude_unset=True)
 
         current_user, current_tenant_id = current_account_with_tenant()
 
         # check embedding model setting
         if (
-            data.get("indexing_technique") == "high_quality"
-            and data.get("embedding_model_provider") is not None
-            and data.get("embedding_model") is not None
+            payload.indexing_technique == "high_quality"
+            and payload.embedding_model_provider is not None
+            and payload.embedding_model is not None
         ):
             DatasetService.check_embedding_model_setting(
-                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
+                dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
             )
 
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         DatasetPermissionService.check_permission(
-            current_user, dataset, data.get("permission"), data.get("partial_member_list")
+            current_user, dataset, payload.permission, payload.partial_member_list
         )
 
-        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
+        dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
 
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -518,15 +447,10 @@ class DatasetApi(Resource):
         result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
         tenant_id = current_tenant_id
 
-        if data.get("partial_member_list") and data.get("permission") == "partial_members":
-            DatasetPermissionService.update_partial_member_list(
-                tenant_id, dataset_id_str, data.get("partial_member_list")
-            )
+        if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
+            DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
         # clear partial member list when permission is only_me or all_team_members
-        elif (
-            data.get("permission") == DatasetPermissionEnum.ONLY_ME
-            or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
-        ):
+        elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
             DatasetPermissionService.clear_partial_member_list(dataset_id_str)
 
         partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@@ -615,24 +539,10 @@ class DatasetIndexingEstimateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("info_list", type=dict, required=True, nullable=True, location="json")
-            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                required=True,
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                location="json",
-            )
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump()
         _, current_tenant_id = current_account_with_tenant()
         # validate args
         DocumentService.estimate_args_validate(args)
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 8ac285e9f8..2520111281 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -6,31 +6,14 @@ from typing import Literal, cast
 
 import sqlalchemy as sa
 from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel
 from sqlalchemy import asc, desc, select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
-from controllers.console.app.error import (
-    ProviderModelCurrentlyNotSupportError,
-    ProviderNotInitializeError,
-    ProviderQuotaExceededError,
-)
-from controllers.console.datasets.error import (
-    ArchivedDocumentImmutableError,
-    DocumentAlreadyFinishedError,
-    DocumentIndexingError,
-    IndexingEstimateError,
-    InvalidActionError,
-    InvalidMetadataError,
-)
-from controllers.console.wraps import (
-    account_initialization_required,
-    cloud_edition_billing_rate_limit_check,
-    cloud_edition_billing_resource_check,
-    setup_required,
-)
 from core.errors.error import (
     LLMBadRequestError,
     ModelCurrentlyNotSupportError,
@@ -55,10 +38,30 @@ from fields.document_fields import (
 )
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
-from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
 from models.dataset import DocumentPipelineExecutionLog
 from services.dataset_service import DatasetService, DocumentService
-from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
+
+from ..app.error import (
+    ProviderModelCurrentlyNotSupportError,
+    ProviderNotInitializeError,
+    ProviderQuotaExceededError,
+)
+from ..datasets.error import (
+    ArchivedDocumentImmutableError,
+    DocumentAlreadyFinishedError,
+    DocumentIndexingError,
+    IndexingEstimateError,
+    InvalidActionError,
+    InvalidMetadataError,
+)
+from ..wraps import (
+    account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
+    cloud_edition_billing_resource_check,
+    setup_required,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
 dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
 
 
+class DocumentRetryPayload(BaseModel):
+    document_ids: list[str]
+
+
+class DocumentRenamePayload(BaseModel):
+    name: str
+
+
+register_schema_models(
+    console_ns,
+    KnowledgeConfig,
+    ProcessRule,
+    RetrievalModel,
+    DocumentRetryPayload,
+    DocumentRenamePayload,
+)
+
+
 class DocumentResource(Resource):
     def get_document(self, dataset_id: str, document_id: str) -> Document:
         current_user, current_tenant_id = current_account_with_tenant()
@@ -311,6 +332,7 @@ class DatasetDocumentListApi(Resource):
     @marshal_with(dataset_and_document_model)
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
     def post(self, dataset_id):
         current_user, _ = current_account_with_tenant()
         dataset_id = str(dataset_id)
@@ -329,23 +351,7 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
-            )
-            .add_argument("data_source", type=dict, required=False, location="json")
-            .add_argument("process_rule", type=dict, required=False, location="json")
-            .add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
-            .add_argument("original_document_id", type=str, required=False, location="json")
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
-            .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
-            .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-        )
-        args = parser.parse_args()
-        knowledge_config = KnowledgeConfig.model_validate(args)
+        knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
 
         if not dataset.indexing_technique and not knowledge_config.indexing_technique:
             raise ValueError("indexing_technique is required.")
@@ -391,17 +397,7 @@ class DatasetDocumentListApi(Resource):
 class DatasetInitApi(Resource):
     @console_ns.doc("init_dataset")
     @console_ns.doc(description="Initialize dataset with documents")
-    @console_ns.expect(
-        console_ns.model(
-            "DatasetInitRequest",
-            {
-                "upload_file_id": fields.String(required=True, description="Upload file ID"),
-                "indexing_technique": fields.String(description="Indexing technique"),
fields.Raw(description="Processing rules"), - "data_source": fields.Raw(description="Data source configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) @console_ns.response(400, "Invalid request parameters") @setup_required @@ -416,27 +412,7 @@ class DatasetInitApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() - parser = ( - reqparse.RequestParser() - .add_argument( - "indexing_technique", - type=str, - choices=Dataset.INDEXING_TECHNIQUE_LIST, - required=True, - nullable=False, - location="json", - ) - .add_argument("data_source", type=dict, required=True, nullable=True, location="json") - .add_argument("process_rule", type=dict, required=True, nullable=True, location="json") - .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") - .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - ) - args = parser.parse_args() - - knowledge_config = KnowledgeConfig.model_validate(args) + knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) if knowledge_config.indexing_technique == "high_quality": if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") @@ -444,9 +420,9 @@ class DatasetInitApi(Resource): model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_tenant_id, - provider=args["embedding_model_provider"], + provider=knowledge_config.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=args["embedding_model"], + model=knowledge_config.embedding_model, ) except InvokeAuthorizationError: raise ProviderNotInitializeError( @@ -1077,19 +1053,16 @@ class DocumentRetryApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[DocumentRetryPayload.__name__]) def post(self, dataset_id): """retry document.""" - - parser = reqparse.RequestParser().add_argument( - "document_ids", type=list, required=True, nullable=False, location="json" - ) - args = parser.parse_args() + payload = DocumentRetryPayload.model_validate(console_ns.payload or {}) dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) retry_documents = [] if not dataset: raise NotFound("Dataset not found.") - for document_id in args["document_ids"]: + for document_id in payload.document_ids: try: document_id = str(document_id) @@ -1122,6 +1095,7 @@ class DocumentRenameApi(DocumentResource): @login_required @account_initialization_required @marshal_with(document_fields) + @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator current_user, _ = current_account_with_tenant() @@ -1131,11 +1105,10 @@ class DocumentRenameApi(DocumentResource): if not dataset: raise NotFound("Dataset not found.") 
         DatasetService.check_dataset_operator_permission(current_user, dataset)
-        parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
+        payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
 
         try:
-            document = DocumentService.rename_document(dataset_id, document_id, args["name"])
+            document = DocumentService.rename_document(dataset_id, document_id, payload.name)
         except services.errors.document.DocumentIndexingError:
             raise DocumentIndexingError("Cannot delete document during indexing.")
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 2fe7d42e46..ee390cbfb7 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -1,11 +1,13 @@
 import uuid
 
 from flask import request
-from flask_restx import Resource, marshal, reqparse
+from flask_restx import Resource, marshal
+from pydantic import BaseModel, Field
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import (
@@ -36,6 +38,56 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
 from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
 
 
+class SegmentListQuery(BaseModel):
+    limit: int = Field(default=20, ge=1, le=100)
+    status: list[str] = Field(default_factory=list)
+    hit_count_gte: int | None = None
+    enabled: str = Field(default="all")
+    keyword: str | None = None
+    page: int = Field(default=1, ge=1)
+
+
+class SegmentCreatePayload(BaseModel):
+    content: str
+    answer: str | None = None
+    keywords: list[str] | None = None
+
+
+class SegmentUpdatePayload(BaseModel):
+    content: str
+    answer: str | None = None
+    keywords: list[str] | None = None
+    regenerate_child_chunks: bool = False
+
+
+class BatchImportPayload(BaseModel):
+    upload_file_id: str
+
+
+class ChildChunkCreatePayload(BaseModel):
+    content: str
+
+
+class ChildChunkUpdatePayload(BaseModel):
+    content: str
+
+
+class ChildChunkBatchUpdatePayload(BaseModel):
+    chunks: list[ChildChunkUpdateArgs]
+
+
+register_schema_models(
+    console_ns,
+    SegmentListQuery,
+    SegmentCreatePayload,
+    SegmentUpdatePayload,
+    BatchImportPayload,
+    ChildChunkCreatePayload,
+    ChildChunkUpdatePayload,
+    ChildChunkBatchUpdatePayload,
+)
+
+
 @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
 class DatasetDocumentSegmentListApi(Resource):
     @setup_required
@@ -60,23 +112,18 @@ class DatasetDocumentSegmentListApi(Resource):
         if not document:
             raise NotFound("Document not found.")
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("limit", type=int, default=20, location="args")
-            .add_argument("status", type=str, action="append", default=[], location="args")
-            .add_argument("hit_count_gte", type=int, default=None, location="args")
-            .add_argument("enabled", type=str, default="all", location="args")
-            .add_argument("keyword", type=str, default=None, location="args")
-            .add_argument("page", type=int, default=1, location="args")
+        args = SegmentListQuery.model_validate(
+            {
+                **request.args.to_dict(),
+                "status": request.args.getlist("status"),
+            }
         )
-        args = parser.parse_args()
-
-        page = args["page"]
-        limit = min(args["limit"], 100)
args["status"] - hit_count_gte = args["hit_count_gte"] - keyword = args["keyword"] + page = args.page + limit = min(args.limit, 100) + status_list = args.status + hit_count_gte = args.hit_count_gte + keyword = args.keyword query = ( select(DocumentSegment) @@ -96,10 +143,10 @@ class DatasetDocumentSegmentListApi(Resource): if keyword: query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) - if args["enabled"].lower() != "all": - if args["enabled"].lower() == "true": + if args.enabled.lower() != "all": + if args.enabled.lower() == "true": query = query.where(DocumentSegment.enabled == True) - elif args["enabled"].lower() == "false": + elif args.enabled.lower() == "false": query = query.where(DocumentSegment.enabled == False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @@ -210,6 +257,7 @@ class DatasetDocumentSegmentAddApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[SegmentCreatePayload.__name__]) def post(self, dataset_id, document_id): current_user, current_tenant_id = current_account_with_tenant() @@ -246,15 +294,10 @@ class DatasetDocumentSegmentAddApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) # validate args - parser = ( - reqparse.RequestParser() - .add_argument("content", type=str, required=True, nullable=False, location="json") - .add_argument("answer", type=str, required=False, nullable=True, location="json") - .add_argument("keywords", type=list, required=False, nullable=True, location="json") - ) - args = parser.parse_args() - SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.create_segment(args, document, dataset) + payload = SegmentCreatePayload.model_validate(console_ns.payload or {}) + payload_dict = payload.model_dump(exclude_none=True) + SegmentService.segment_create_args_validate(payload_dict, document) + segment = SegmentService.create_segment(payload_dict, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @@ -265,6 +308,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__]) def patch(self, dataset_id, document_id, segment_id): current_user, current_tenant_id = current_account_with_tenant() @@ -313,18 +357,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) # validate args - parser = ( - reqparse.RequestParser() - .add_argument("content", type=str, required=True, nullable=False, location="json") - .add_argument("answer", type=str, required=False, nullable=True, location="json") - .add_argument("keywords", type=list, required=False, nullable=True, location="json") - .add_argument( - "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" - ) + payload = SegmentUpdatePayload.model_validate(console_ns.payload or {}) + payload_dict = payload.model_dump(exclude_none=True) + SegmentService.segment_create_args_validate(payload_dict, document) + segment = SegmentService.update_segment( + SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, 
+            SegmentUpdateArgs.model_validate(payload_dict), segment, document, dataset
         )
-        args = parser.parse_args()
-        SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
 
         return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
 
     @setup_required
@@ -377,6 +415,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[BatchImportPayload.__name__])
     def post(self, dataset_id, document_id):
         current_user, current_tenant_id = current_account_with_tenant()
 
@@ -391,11 +430,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         if not document:
             raise NotFound("Document not found.")
 
-        parser = reqparse.RequestParser().add_argument(
-            "upload_file_id", type=str, required=True, nullable=False, location="json"
-        )
-        args = parser.parse_args()
-        upload_file_id = args["upload_file_id"]
+        payload = BatchImportPayload.model_validate(console_ns.payload or {})
+        upload_file_id = payload.upload_file_id
 
         upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
         if not upload_file:
@@ -446,6 +482,7 @@ class ChildChunkAddApi(Resource):
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
     def post(self, dataset_id, document_id, segment_id):
         current_user, current_tenant_id = current_account_with_tenant()
 
@@ -491,13 +528,9 @@ class ChildChunkAddApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser().add_argument(
-            "content", type=str, required=True, nullable=False, location="json"
-        )
-        args = parser.parse_args()
         try:
-            content = args["content"]
-            child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
+            payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
+            child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
         except ChildChunkIndexingServiceError as e:
             raise ChildChunkIndexingError(str(e))
         return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@@ -529,18 +562,17 @@ class ChildChunkAddApi(Resource):
         )
         if not segment:
             raise NotFound("Segment not found.")
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("limit", type=int, default=20, location="args")
-            .add_argument("keyword", type=str, default=None, location="args")
-            .add_argument("page", type=int, default=1, location="args")
+        args = SegmentListQuery.model_validate(
+            {
+                "limit": request.args.get("limit", default=20, type=int),
+                "keyword": request.args.get("keyword"),
+                "page": request.args.get("page", default=1, type=int),
+            }
         )
-        args = parser.parse_args()
-
-        page = args["page"]
-        limit = min(args["limit"], 100)
-        keyword = args["keyword"]
+        page = args.page
+        limit = min(args.limit, 100)
+        keyword = args.keyword
 
         child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
         return {
@@ -588,14 +620,9 @@ class ChildChunkAddApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser().add_argument(
location="json" - ) - args = parser.parse_args() + payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {}) try: - chunks_data = args["chunks"] - chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data] - child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) + child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunks, child_chunk_fields)}, 200 @@ -665,6 +692,7 @@ class ChildChunkUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__]) def patch(self, dataset_id, document_id, segment_id, child_chunk_id): current_user, current_tenant_id = current_account_with_tenant() @@ -711,13 +739,9 @@ class ChildChunkUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) # validate args - parser = reqparse.RequestParser().add_argument( - "content", type=str, required=True, nullable=False, location="json" - ) - args = parser.parse_args() try: - content = args["content"] - child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset) + payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {}) + child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 950884e496..89c9fcad36 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,8 +1,10 @@ from flask import request -from flask_restx import Resource, fields, marshal, reqparse +from flask_restx import Resource, fields, marshal +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required @@ -71,10 +73,38 @@ except KeyError: dataset_detail_model = _build_dataset_detail_model() -def _validate_name(name: str) -> str: - if not name or len(name) < 1 or len(name) > 100: - raise ValueError("Name must be between 1 to 100 characters.") - return name +class ExternalKnowledgeApiPayload(BaseModel): + name: str = Field(..., min_length=1, max_length=40) + settings: dict[str, object] + + +class ExternalDatasetCreatePayload(BaseModel): + external_knowledge_api_id: str + external_knowledge_id: str + name: str = Field(..., min_length=1, max_length=40) + description: str | None = Field(None, max_length=400) + external_retrieval_model: dict[str, object] | None = None + + +class ExternalHitTestingPayload(BaseModel): + query: str + external_retrieval_model: dict[str, object] | None = None + metadata_filtering_conditions: dict[str, object] | None = None + + +class BedrockRetrievalPayload(BaseModel): + retrieval_setting: dict[str, object] + query: str + knowledge_id: str + + 
+register_schema_models(
+    console_ns,
+    ExternalKnowledgeApiPayload,
+    ExternalDatasetCreatePayload,
+    ExternalHitTestingPayload,
+    BedrockRetrievalPayload,
+)
 
 
 @console_ns.route("/datasets/external-knowledge-api")
@@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
     def post(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                required=True,
-                help="Name is required. Name must be between 1 to 100 characters.",
-                type=_validate_name,
-            )
-            .add_argument(
-                "settings",
-                type=dict,
-                location="json",
-                nullable=False,
-                required=True,
-            )
-        )
-        args = parser.parse_args()
+        payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
 
-        ExternalDatasetService.validate_api_list(args["settings"])
+        ExternalDatasetService.validate_api_list(payload.settings)
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
         if not current_user.is_dataset_editor:
@@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
 
         try:
             external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
-                tenant_id=current_tenant_id, user_id=current_user.id, args=args
+                tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
     def patch(self, external_knowledge_api_id):
         current_user, current_tenant_id = current_account_with_tenant()
         external_knowledge_api_id = str(external_knowledge_api_id)
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                required=True,
-                help="type is required. Name must be between 1 to 100 characters.",
-                type=_validate_name,
-            )
-            .add_argument(
-                "settings",
-                type=dict,
-                location="json",
-                nullable=False,
-                required=True,
-            )
-        )
-        args = parser.parse_args()
-        ExternalDatasetService.validate_api_list(args["settings"])
+        payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
+        ExternalDatasetService.validate_api_list(payload.settings)
 
         external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
             tenant_id=current_tenant_id,
             user_id=current_user.id,
             external_knowledge_api_id=external_knowledge_api_id,
-            args=args,
+            args=payload.model_dump(),
         )
 
         return external_knowledge_api.to_dict(), 200
@@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource):
 class ExternalDatasetCreateApi(Resource):
     @console_ns.doc("create_external_dataset")
     @console_ns.doc(description="Create external knowledge dataset")
-    @console_ns.expect(
-        console_ns.model(
-            "CreateExternalDatasetRequest",
-            {
-                "external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
-                "external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
-                "name": fields.String(required=True, description="Dataset name"),
-                "description": fields.String(description="Dataset description"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
     @console_ns.response(201, "External dataset created successfully", dataset_detail_model)
     @console_ns.response(400, "Invalid parameters")
     @console_ns.response(403, "Permission denied")
@@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
     def post(self):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
-            .add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
-            .add_argument(
-                "name",
-                nullable=False,
-                required=True,
-                help="name is required. Name must be between 1 to 100 characters.",
-                type=_validate_name,
-            )
-            .add_argument("description", type=str, required=False, nullable=True, location="json")
-            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
-        )
-
-        args = parser.parse_args()
+        payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
         if not current_user.is_dataset_editor:
@@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
     @console_ns.doc("test_external_knowledge_retrieval")
     @console_ns.doc(description="Test external knowledge retrieval for dataset")
     @console_ns.doc(params={"dataset_id": "Dataset ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "ExternalHitTestingRequest",
-            {
-                "query": fields.String(required=True, description="Query text for testing"),
-                "retrieval_model": fields.Raw(description="Retrieval model configuration"),
-                "external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
     @console_ns.response(200, "External hit testing completed successfully")
     @console_ns.response(404, "Dataset not found")
     @console_ns.response(400, "Invalid parameters")
@@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("query", type=str, location="json")
-            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
-            .add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
-        )
-        args = parser.parse_args()
-
-        HitTestingService.hit_testing_args_check(args)
+        payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
+        HitTestingService.hit_testing_args_check(payload.model_dump())
 
         try:
             response = HitTestingService.external_retrieve(
                 dataset=dataset,
-                query=args["query"],
+                query=payload.query,
                 account=current_user,
-                external_retrieval_model=args["external_retrieval_model"],
-                metadata_filtering_conditions=args["metadata_filtering_conditions"],
+                external_retrieval_model=payload.external_retrieval_model,
+                metadata_filtering_conditions=payload.metadata_filtering_conditions,
             )
 
             return response
@@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource):
     # this api is only for internal testing
     @console_ns.doc("bedrock_retrieval_test")
    @console_ns.doc(description="Bedrock retrieval test (internal use only)")
-    @console_ns.expect(
-        console_ns.model(
-            "BedrockRetrievalTestRequest",
-            {
-                "retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
-                "query": fields.String(required=True, description="Query text"),
-                "knowledge_id": fields.String(required=True, description="Knowledge ID"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
     @console_ns.response(200, "Bedrock retrieval test completed")
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
-            .add_argument(
-                "query",
-                nullable=False,
-                required=True,
-                type=str,
-            )
-            .add_argument("knowledge_id", nullable=False, required=True, type=str)
-        )
-        args = parser.parse_args()
+        payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
         # Call the knowledge retrieval service
         result = ExternalDatasetTestService.knowledge_retrieval(
-            args["retrieval_setting"], args["query"], args["knowledge_id"]
+            payload.retrieval_setting, payload.query, payload.knowledge_id
         )
         return result, 200
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index 7ba2eeb7dd..932cb4fcce 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,13 +1,17 @@
-from flask_restx import Resource, fields
+from flask_restx import Resource
 
-from controllers.console import console_ns
-from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
-from controllers.console.wraps import (
+from controllers.common.schema import register_schema_model
+from libs.login import login_required
+
+from .. import console_ns
+from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
+from ..wraps import (
     account_initialization_required,
     cloud_edition_billing_rate_limit_check,
     setup_required,
 )
-from libs.login import login_required
+
+register_schema_model(console_ns, HitTestingPayload)
 
 
 @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
     @console_ns.doc("test_dataset_retrieval")
     @console_ns.doc(description="Test dataset knowledge retrieval")
     @console_ns.doc(params={"dataset_id": "Dataset ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "HitTestingRequest",
-            {
-                "query": fields.String(required=True, description="Query text for testing"),
-                "retrieval_model": fields.Raw(description="Retrieval model configuration"),
-                "top_k": fields.Integer(description="Number of top results to return"),
-                "score_threshold": fields.Float(description="Score threshold for filtering results"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
     @console_ns.response(200, "Hit testing completed successfully")
     @console_ns.response(404, "Dataset not found")
     @console_ns.response(400, "Invalid parameters")
@@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
         dataset_id_str = str(dataset_id)
         dataset = self.get_and_validate_dataset(dataset_id_str)
 
-        args = self.parse_args()
+        payload = HitTestingPayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
         self.hit_testing_args_check(args)
 
         return self.perform_hit_testing(dataset, args)
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 99d4d5a29c..fac90a0135 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,6 +1,8 @@
 import logging
+from typing import Any
 
-from flask_restx import marshal, reqparse
+from flask_restx import marshal
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
@@ -27,6 +29,12 @@ from services.hit_testing_service import HitTestingService
 logger = logging.getLogger(__name__)
 
 
+class HitTestingPayload(BaseModel):
+    query: str = Field(max_length=250)
+    retrieval_model: dict[str, Any] | None = None
+    external_retrieval_model: dict[str, Any] | None = None
+
+
 class DatasetsHitTestingBase:
     @staticmethod
     def get_and_validate_dataset(dataset_id: str):
@@ -43,19 +51,9 @@ class DatasetsHitTestingBase:
         return dataset
 
     @staticmethod
-    def hit_testing_args_check(args):
+    def hit_testing_args_check(args: dict[str, Any]):
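+        # args is a plain dict (callers pass payload.model_dump() output);
+        # the service applies the remaining semantic checks.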
         HitTestingService.hit_testing_args_check(args)
 
-    @staticmethod
-    def parse_args():
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("query", type=str, location="json")
-            .add_argument("retrieval_model", type=dict, required=False, location="json")
-            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
-        )
-        return parser.parse_args()
-
     @staticmethod
     def perform_hit_testing(dataset, args):
         assert isinstance(current_user, Account)
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index 72b2ff0ff8..8eead1696a 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -1,8 +1,10 @@
 from typing import Literal
 
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel
 from werkzeug.exceptions import NotFound
 
+from controllers.common.schema import register_schema_model, register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from fields.dataset_fields import dataset_metadata_fields
@@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
 from services.metadata_service import MetadataService
 
 
+class MetadataUpdatePayload(BaseModel):
+    name: str
+
+
+register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
+register_schema_model(console_ns, MetadataUpdatePayload)
+
+
 @console_ns.route("/datasets/<uuid:dataset_id>/metadata")
 class DatasetMetadataCreateApi(Resource):
     @setup_required
@@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
     @account_initialization_required
     @enterprise_license_required
     @marshal_with(dataset_metadata_fields)
+    @console_ns.expect(console_ns.models[MetadataArgs.__name__])
     def post(self, dataset_id):
         current_user, _ = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("type", type=str, required=True, nullable=False, location="json")
-            .add_argument("name", type=str, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
-        metadata_args = MetadataArgs.model_validate(args)
+        metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
 
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
@@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
     @account_initialization_required
     @enterprise_license_required
     @marshal_with(dataset_metadata_fields)
+    @console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
     def patch(self, dataset_id, metadata_id):
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
-        name = args["name"]
+        payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
+        name = payload.name
 
         dataset_id_str = str(dataset_id)
         metadata_id_str = str(metadata_id)
@@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
     @login_required
     @account_initialization_required
     @enterprise_license_required
+    @console_ns.expect(console_ns.models[MetadataOperationData.__name__])
     def post(self, dataset_id):
         current_user, _ = current_account_with_tenant()
         dataset_id_str = str(dataset_id)
@@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
             raise NotFound("Dataset not found.")
         DatasetService.check_dataset_permission(dataset, current_user)
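+        # MetadataOperationData validates the entire JSON body, i.e. {"operation_data": [...]}.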
- parser = reqparse.RequestParser().add_argument( - "operation_data", type=list, required=True, nullable=False, location="json" - ) - args = parser.parse_args() - metadata_args = MetadataOperationData.model_validate(args) + metadata_args = MetadataOperationData.model_validate(console_ns.payload or {}) MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index cf9e5d2990..1a47e226e5 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -1,20 +1,63 @@ +from typing import Any + from flask import make_response, redirect, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler -from libs.helper import StrLen from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService +class DatasourceCredentialPayload(BaseModel): + name: str | None = Field(default=None, max_length=100) + credentials: dict[str, Any] + + +class DatasourceCredentialDeletePayload(BaseModel): + credential_id: str + + +class DatasourceCredentialUpdatePayload(BaseModel): + credential_id: str + name: str | None = Field(default=None, max_length=100) + credentials: dict[str, Any] | None = None + + +class DatasourceCustomClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enable_oauth_custom_client: bool | None = None + + +class DatasourceDefaultPayload(BaseModel): + id: str + + +class DatasourceUpdateNamePayload(BaseModel): + credential_id: str + name: str = Field(max_length=100) + + +register_schema_models( + console_ns, + DatasourceCredentialPayload, + DatasourceCredentialDeletePayload, + DatasourceCredentialUpdatePayload, + DatasourceCustomClientPayload, + DatasourceDefaultPayload, + DatasourceUpdateNamePayload, +) + + @console_ns.route("/oauth/plugin//datasource/get-authorization-url") class DatasourcePluginOAuthAuthorizationUrl(Resource): @setup_required @@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_datasource = ( - reqparse.RequestParser() - .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") -) - - @console_ns.route("/auth/plugin/datasource/") class DatasourceAuth(Resource): - @console_ns.expect(parser_datasource) + @console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -138,7 +174,7 @@ class DatasourceAuth(Resource): def post(self, provider_id: str): _, current_tenant_id = 
current_account_with_tenant() - args = parser_datasource.parse_args() + payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() @@ -146,8 +182,8 @@ class DatasourceAuth(Resource): datasource_provider_service.add_datasource_api_key_provider( tenant_id=current_tenant_id, provider_id=datasource_provider_id, - credentials=args["credentials"], - name=args["name"], + credentials=payload.credentials, + name=payload.name, ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) @@ -169,14 +205,9 @@ class DatasourceAuth(Resource): return {"result": datasources}, 200 -parser_datasource_delete = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/auth/plugin/datasource//delete") class DatasourceAuthDeleteApi(Resource): - @console_ns.expect(parser_datasource_delete) + @console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource): plugin_id = datasource_provider_id.plugin_id provider_name = datasource_provider_id.provider_name - args = parser_datasource_delete.parse_args() + payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {}) datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( tenant_id=current_tenant_id, - auth_id=args["credential_id"], + auth_id=payload.credential_id, provider=provider_name, plugin_id=plugin_id, ) return {"result": "success"}, 200 -parser_datasource_update = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") - .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/auth/plugin/datasource//update") class DatasourceAuthUpdateApi(Resource): - @console_ns.expect(parser_datasource_update) + @console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource): _, current_tenant_id = current_account_with_tenant() datasource_provider_id = DatasourceProviderID(provider_id) - args = parser_datasource_update.parse_args() + payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {}) datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_credentials( tenant_id=current_tenant_id, - auth_id=args["credential_id"], + auth_id=payload.credential_id, provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, - credentials=args.get("credentials", {}), - name=args.get("name", None), + credentials=payload.credentials or {}, + name=payload.name, ) return {"result": "success"}, 201 @@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource): return {"result": jsonable_encoder(datasources)}, 200 -parser_datasource_custom = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") -) - - 
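One behavioural difference worth noting in the datasource handlers above: reqparse aborted the request with an HTTP 400 on bad input, whereas Pydantic's model_validate raises pydantic.ValidationError, which presumes an error handler elsewhere translates it into a 400 response (not shown in this changeset). A tiny illustration, with CredentialPayload as a stand-in model:

    from pydantic import BaseModel, Field, ValidationError


    class CredentialPayload(BaseModel):  # stand-in for the credential payloads above
        name: str | None = Field(default=None, max_length=100)
        credentials: dict


    try:
        CredentialPayload.model_validate({})  # body missing the required "credentials" key
    except ValidationError as exc:
        print(exc.errors()[0]["type"])  # -> "missing"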
@console_ns.route("/auth/plugin/datasource//custom-client") class DatasourceAuthOauthCustomClient(Resource): - @console_ns.expect(parser_datasource_custom) + @console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource): def post(self, provider_id: str): _, current_tenant_id = current_account_with_tenant() - args = parser_datasource_custom.parse_args() + payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.setup_oauth_custom_client_params( tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, - client_params=args.get("client_params", {}), - enabled=args.get("enable_oauth_custom_client", False), + client_params=payload.client_params or {}, + enabled=payload.enable_oauth_custom_client or False, ) return {"result": "success"}, 200 @@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource): return {"result": "success"}, 200 -parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json") - - @console_ns.route("/auth/plugin/datasource//default") class DatasourceAuthDefaultApi(Resource): - @console_ns.expect(parser_default) + @console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource): def post(self, provider_id: str): _, current_tenant_id = current_account_with_tenant() - args = parser_default.parse_args() + payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.set_default_datasource_provider( tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, - credential_id=args["id"], + credential_id=payload.id, ) return {"result": "success"}, 200 -parser_update_name = ( - reqparse.RequestParser() - .add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/auth/plugin/datasource//update-name") class DatasourceUpdateProviderNameApi(Resource): - @console_ns.expect(parser_update_name) + @console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource): def post(self, provider_id: str): _, current_tenant_id = current_account_with_tenant() - args = parser_update_name.parse_args() + payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_provider_name( tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, - name=args["name"], - credential_id=args["credential_id"], + name=payload.name, + credential_id=payload.credential_id, ) return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py 
b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index f589bba3bf..6e0cd31b8d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,9 +1,11 @@ import logging from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService logger = logging.getLogger(__name__) -def _validate_name(name: str) -> str: - if not name or len(name) < 1 or len(name) > 40: - raise ValueError("Name must be between 1 to 40 characters.") - return name - - -def _validate_description_length(description: str) -> str: - if len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - @console_ns.route("/rag/pipeline/templates") class PipelineTemplateListApi(Resource): @setup_required @@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource): return pipeline_template, 200 +class Payload(BaseModel): + name: str = Field(..., min_length=1, max_length=40) + description: str = Field(default="", max_length=400) + icon_info: dict[str, object] | None = None + + +register_schema_models(console_ns, Payload) + + @console_ns.route("/rag/pipeline/customized/templates/") class CustomizedPipelineTemplateApi(Resource): @setup_required @@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource): @account_initialization_required @enterprise_license_required def patch(self, template_id: str): - parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument( - "description", - type=_validate_description_length, - nullable=True, - required=False, - default="", - ) - .add_argument( - "icon_info", - type=dict, - location="json", - nullable=True, - ) - ) - args = parser.parse_args() - pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args) + payload = Payload.model_validate(console_ns.payload or {}) + pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump()) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) return 200 @@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource): @console_ns.route("/rag/pipelines//customized/publish") class PublishCustomizedPipelineTemplateApi(Resource): + @console_ns.expect(console_ns.models[Payload.__name__]) @setup_required @login_required @account_initialization_required @enterprise_license_required @knowledge_pipeline_publish_enabled def post(self, pipeline_id: str): - parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument( - "description", - type=_validate_description_length, - nullable=True, - required=False, - default="", - ) - .add_argument( - "icon_info", - type=dict, - location="json", - nullable=True, - ) - ) - args = parser.parse_args() + payload = Payload.model_validate(console_ns.payload or {}) rag_pipeline_service = RagPipelineService() - 
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) + rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump()) return {"result": "success"} diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 98876e9f5e..e65cb19b39 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,8 +1,10 @@ -from flask_restx import Resource, marshal, reqparse +from flask_restx import Resource, marshal +from pydantic import BaseModel from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden import services +from controllers.common.schema import register_schema_model from controllers.console import console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import ( @@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService +class RagPipelineDatasetImportPayload(BaseModel): + yaml_content: str + + +register_schema_model(console_ns, RagPipelineDatasetImportPayload) + + @console_ns.route("/rag/pipeline/dataset") class CreateRagPipelineDatasetApi(Resource): + @console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__]) @setup_required @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def post(self): - parser = reqparse.RequestParser().add_argument( - "yaml_content", - type=str, - nullable=False, - required=True, - help="yaml_content is required.", - ) - - args = parser.parse_args() + payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {}) current_user, current_tenant_id = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: @@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource): ), permission=DatasetPermissionEnum.ONLY_ME, partial_member_list=None, - yaml_content=args["yaml_content"], + yaml_content=payload.yaml_content, ) try: with Session(db.engine) as session: diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 858ba94bf8..720e2ce365 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,11 +1,13 @@ import logging -from typing import NoReturn +from typing import Any, NoReturn -from flask import Response -from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask import Response, request +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, @@ -33,19 +35,21 @@ logger = logging.getLogger(__name__) def _create_pagination_parser(): - parser = ( - reqparse.RequestParser() - .add_argument( - "page", - type=inputs.int_range(1, 100_000), - required=False, - default=1, 
- location="args", - help="the page of data requested", - ) - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") - ) - return parser + class PaginationQuery(BaseModel): + page: int = Field(default=1, ge=1, le=100_000) + limit: int = Field(default=20, ge=1, le=100) + + register_schema_models(console_ns, PaginationQuery) + + return PaginationQuery + + +class WorkflowDraftVariablePatchPayload(BaseModel): + name: str | None = None + value: Any | None = None + + +register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: @@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource): """ Get draft workflow """ - parser = _create_pagination_parser() - args = parser.parse_args() + pagination = _create_pagination_parser() + query = pagination.model_validate(request.args.to_dict()) # fetch draft workflow by app_model rag_pipeline_service = RagPipelineService() @@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource): ) workflow_vars = draft_var_srv.list_variables_without_values( app_id=pipeline.id, - page=args.page, - limit=args.limit, + page=query.page, + limit=query.limit, ) return workflow_vars @@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource): @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + @console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__]) def patch(self, pipeline: Pipeline, variable_id: str): # Request payload for file types: # @@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource): # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # } - parser = ( - reqparse.RequestParser() - .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") - .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") - ) - draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - args = parser.parse_args(strict=True) + payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) variable = draft_var_srv.get_variable(variable_id=variable_id) if variable is None: diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index d658d65b71..d43ee9a6e0 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,6 +1,9 @@ -from flask_restx import Resource, marshal_with, reqparse # type: ignore +from flask import request +from flask_restx import Resource, marshal_with # type: ignore +from pydantic import BaseModel, Field from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( @@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService +class RagPipelineImportPayload(BaseModel): + mode: str + yaml_content: str | None = None + yaml_url: str | None = None + name: str | None = None + description: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + pipeline_id: str | None = None + + +class 
IncludeSecretQuery(BaseModel): + include_secret: str = Field(default="false") + + +register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery) + + @console_ns.route("/rag/pipelines/imports") class RagPipelineImportApi(Resource): @setup_required @@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource): @account_initialization_required @edit_permission_required @marshal_with(pipeline_import_fields) + @console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__]) def post(self): # Check user role first current_user, _ = current_account_with_tenant() - - parser = ( - reqparse.RequestParser() - .add_argument("mode", type=str, required=True, location="json") - .add_argument("yaml_content", type=str, location="json") - .add_argument("yaml_url", type=str, location="json") - .add_argument("name", type=str, location="json") - .add_argument("description", type=str, location="json") - .add_argument("icon_type", type=str, location="json") - .add_argument("icon", type=str, location="json") - .add_argument("icon_background", type=str, location="json") - .add_argument("pipeline_id", type=str, location="json") - ) - args = parser.parse_args() + payload = RagPipelineImportPayload.model_validate(console_ns.payload or {}) # Create service with session with Session(db.engine) as session: @@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource): account = current_user result = import_service.import_rag_pipeline( account=account, - import_mode=args["mode"], - yaml_content=args.get("yaml_content"), - yaml_url=args.get("yaml_url"), - pipeline_id=args.get("pipeline_id"), - dataset_name=args.get("name"), + import_mode=payload.mode, + yaml_content=payload.yaml_content, + yaml_url=payload.yaml_url, + pipeline_id=payload.pipeline_id, + dataset_name=payload.name, ) session.commit() @@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource): @edit_permission_required def get(self, pipeline: Pipeline): # Add include_secret params - parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args") - args = parser.parse_args() + query = IncludeSecretQuery.model_validate(request.args.to_dict()) with Session(db.engine) as session: export_service = RagPipelineDslService(session) result = export_service.export_rag_pipeline_dsl( - pipeline=pipeline, include_secret=args["include_secret"] == "true" + pipeline=pipeline, include_secret=query.include_secret == "true" ) return {"data": result}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index a0dc692c4e..debe8eed97 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1,14 +1,16 @@ import json import logging -from typing import cast +from typing import Any, Literal, cast +from uuid import UUID from flask import abort, request -from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore -from flask_restx.inputs import int_range # type: ignore +from flask_restx import Resource, marshal_with # type: ignore +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( ConversationCompletedError, @@ -36,7 +38,7 
@@ from fields.workflow_run_fields import ( workflow_run_pagination_fields, ) from libs import helper -from libs.helper import TimestampField, uuid_value +from libs.helper import TimestampField from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline @@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran logger = logging.getLogger(__name__) +class DraftWorkflowSyncPayload(BaseModel): + graph: dict[str, Any] + hash: str | None = None + environment_variables: list[dict[str, Any]] | None = None + conversation_variables: list[dict[str, Any]] | None = None + rag_pipeline_variables: list[dict[str, Any]] | None = None + features: dict[str, Any] | None = None + + +class NodeRunPayload(BaseModel): + inputs: dict[str, Any] | None = None + + +class NodeRunRequiredPayload(BaseModel): + inputs: dict[str, Any] + + +class DatasourceNodeRunPayload(BaseModel): + inputs: dict[str, Any] + datasource_type: str + credential_id: str | None = None + + +class DraftWorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + datasource_type: str + datasource_info_list: list[dict[str, Any]] + start_node_id: str + + +class PublishedWorkflowRunPayload(DraftWorkflowRunPayload): + is_preview: bool = False + response_mode: Literal["streaming", "blocking"] = "streaming" + original_document_id: str | None = None + + +class DefaultBlockConfigQuery(BaseModel): + q: str | None = None + + +class WorkflowListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=10, ge=1, le=100) + user_id: str | None = None + named_only: bool = False + + +class WorkflowUpdatePayload(BaseModel): + marked_name: str | None = Field(default=None, max_length=20) + marked_comment: str | None = Field(default=None, max_length=100) + + +class NodeIdQuery(BaseModel): + node_id: str + + +class WorkflowRunQuery(BaseModel): + last_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class DatasourceVariablesPayload(BaseModel): + datasource_type: str + datasource_info: dict[str, Any] + start_node_id: str + start_node_title: str + + +register_schema_models( + console_ns, + DraftWorkflowSyncPayload, + NodeRunPayload, + NodeRunRequiredPayload, + DatasourceNodeRunPayload, + DraftWorkflowRunPayload, + PublishedWorkflowRunPayload, + DefaultBlockConfigQuery, + WorkflowListQuery, + WorkflowUpdatePayload, + NodeIdQuery, + WorkflowRunQuery, + DatasourceVariablesPayload, +) + + @console_ns.route("/rag/pipelines//workflows/draft") class DraftRagPipelineApi(Resource): @setup_required @@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource): content_type = request.headers.get("Content-Type", "") if "application/json" in content_type: - parser = ( - reqparse.RequestParser() - .add_argument("graph", type=dict, required=True, nullable=False, location="json") - .add_argument("hash", type=str, required=False, location="json") - .add_argument("environment_variables", type=list, required=False, location="json") - .add_argument("conversation_variables", type=list, required=False, location="json") - .add_argument("rag_pipeline_variables", type=list, required=False, location="json") - ) - args = parser.parse_args() + payload_dict = console_ns.payload or {} elif "text/plain" in content_type: try: data = json.loads(request.data.decode("utf-8")) @@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource): if not isinstance(data.get("graph"), dict): raise ValueError("graph is not a dict") - args = { + 
payload_dict = { "graph": data.get("graph"), "features": data.get("features"), "hash": data.get("hash"), @@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource): else: abort(415) + payload = DraftWorkflowSyncPayload.model_validate(payload_dict) + try: - environment_variables_list = args.get("environment_variables") or [] + environment_variables_list = payload.environment_variables or [] environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] - conversation_variables_list = args.get("conversation_variables") or [] + conversation_variables_list = payload.conversation_variables or [] conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, - graph=args["graph"], - unique_hash=args.get("hash"), + graph=payload.graph, + unique_hash=payload.hash, account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, - rag_pipeline_variables=args.get("rag_pipeline_variables") or [], + rag_pipeline_variables=payload.rag_pipeline_variables or [], ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() @@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource): } -parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") - - @console_ns.route("/rag/pipelines//workflows/draft/iteration/nodes//run") class RagPipelineDraftRunIterationNodeApi(Resource): - @console_ns.expect(parser_run) + @console_ns.expect(console_ns.models[NodeRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_run.parse_args() + payload = NodeRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = PipelineGenerateService.generate_single_iteration( @@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/loop/nodes//run") class RagPipelineDraftRunLoopNodeApi(Resource): - @console_ns.expect(parser_run) + @console_ns.expect(console_ns.models[NodeRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_run.parse_args() + payload = NodeRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = PipelineGenerateService.generate_single_loop( @@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource): raise InternalServerError() -parser_draft_run = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("datasource_info_list", type=list, required=True, location="json") - .add_argument("start_node_id", type=str, required=True, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/draft/run") class DraftRagPipelineRunApi(Resource): - @console_ns.expect(parser_draft_run) + 
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_draft_run.parse_args() + payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump() try: response = PipelineGenerateService.generate( @@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) -parser_published_run = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("datasource_info_list", type=list, required=True, location="json") - .add_argument("start_node_id", type=str, required=True, location="json") - .add_argument("is_preview", type=bool, required=True, location="json", default=False) - .add_argument("response_mode", type=str, required=True, location="json", default="streaming") - .add_argument("original_document_id", type=str, required=False, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/published/run") class PublishedRagPipelineRunApi(Resource): - @console_ns.expect(parser_published_run) + @console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_published_run.parse_args() - - streaming = args["response_mode"] == "streaming" + payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) + streaming = payload.response_mode == "streaming" try: response = PipelineGenerateService.generate( pipeline=pipeline, user=current_user, args=args, - invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED, streaming=streaming, ) @@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource): # # return result # -parser_rag_run = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("credential_id", type=str, required=False, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//run") class RagPipelinePublishedDatasourceNodeRunApi(Resource): - @console_ns.expect(parser_rag_run) + @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_rag_run.parse_args() - - inputs = args.get("inputs") - if inputs is None: - raise ValueError("missing inputs") - datasource_type = args.get("datasource_type") - if datasource_type is None: - raise ValueError("missing datasource_type") + payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {}) 
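The run endpoints above call payload.model_dump(exclude_none=True) so the args dict handed to the generate service carries no explicit nulls for unset optional fields, and downstream args.get(...) lookups behave as they did with reqparse. A short illustration, with RunPayload as a stand-in model:

    from typing import Any

    from pydantic import BaseModel


    class RunPayload(BaseModel):  # stand-in, not a model from this changeset
        inputs: dict[str, Any]
        credential_id: str | None = None


    payload = RunPayload.model_validate({"inputs": {"q": 1}})
    print(payload.model_dump())                   # {'inputs': {'q': 1}, 'credential_id': None}
    print(payload.model_dump(exclude_none=True))  # {'inputs': {'q': 1}}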
rag_pipeline_service = RagPipelineService() return helper.compact_generate_response( @@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): rag_pipeline_service.run_datasource_workflow_node( pipeline=pipeline, node_id=node_id, - user_inputs=inputs, + user_inputs=payload.inputs, account=current_user, - datasource_type=datasource_type, + datasource_type=payload.datasource_type, is_published=False, - credential_id=args.get("credential_id"), + credential_id=payload.credential_id, ) ) ) @@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/datasource/nodes//run") class RagPipelineDraftDatasourceNodeRunApi(Resource): - @console_ns.expect(parser_rag_run) + @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__]) @setup_required @login_required @edit_permission_required @@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_rag_run.parse_args() - - inputs = args.get("inputs") - if inputs is None: - raise ValueError("missing inputs") - datasource_type = args.get("datasource_type") - if datasource_type is None: - raise ValueError("missing datasource_type") + payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {}) rag_pipeline_service = RagPipelineService() return helper.compact_generate_response( @@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): rag_pipeline_service.run_datasource_workflow_node( pipeline=pipeline, node_id=node_id, - user_inputs=inputs, + user_inputs=payload.inputs, account=current_user, - datasource_type=datasource_type, + datasource_type=payload.datasource_type, is_published=False, - credential_id=args.get("credential_id"), + credential_id=payload.credential_id, ) ) ) -parser_run_api = reqparse.RequestParser().add_argument( - "inputs", type=dict, required=True, nullable=False, location="json" -) - - @console_ns.route("/rag/pipelines//workflows/draft/nodes//run") class RagPipelineDraftNodeRunApi(Resource): - @console_ns.expect(parser_run_api) + @console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__]) @setup_required @login_required @edit_permission_required @@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_run_api.parse_args() - - inputs = args.get("inputs") - if inputs == None: - raise ValueError("missing inputs") + payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {}) + inputs = payload.inputs rag_pipeline_service = RagPipelineService() workflow_node_execution = rag_pipeline_service.run_draft_workflow_node( @@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource): return rag_pipeline_service.get_default_block_configs() -parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args") - - @console_ns.route("/rag/pipelines//workflows/default-workflow-block-configs/") class DefaultRagPipelineBlockConfigApi(Resource): - @console_ns.expect(parser_default) @setup_required @login_required @account_initialization_required @@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource): """ Get default block config """ - args = parser_default.parse_args() - - q = args.get("q") + query = 
DefaultBlockConfigQuery.model_validate(request.args.to_dict()) filters = None - if q: + if query.q: try: - filters = json.loads(args.get("q", "")) + filters = json.loads(query.q) except json.JSONDecodeError: raise ValueError("Invalid filters") @@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource): return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) -parser_wf = ( - reqparse.RequestParser() - .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args") - .add_argument("user_id", type=str, required=False, location="args") - .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") -) - - @console_ns.route("/rag/pipelines//workflows") class PublishedAllRagPipelineApi(Resource): - @console_ns.expect(parser_wf) @setup_required @login_required @account_initialization_required @@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource): """ current_user, _ = current_account_with_tenant() - args = parser_wf.parse_args() - page = args["page"] - limit = args["limit"] - user_id = args.get("user_id") - named_only = args.get("named_only", False) + query = WorkflowListQuery.model_validate(request.args.to_dict()) + + page = query.page + limit = query.limit + user_id = query.user_id + named_only = query.named_only if user_id: if user_id != current_user.id: raise Forbidden() - user_id = cast(str, user_id) rag_pipeline_service = RagPipelineService() with Session(db.engine) as session: @@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource): } -parser_wf_id = ( - reqparse.RequestParser() - .add_argument("marked_name", type=str, required=False, location="json") - .add_argument("marked_comment", type=str, required=False, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/") class RagPipelineByIdApi(Resource): - @console_ns.expect(parser_wf_id) @setup_required @login_required @account_initialization_required @@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource): # Check permission current_user, _ = current_account_with_tenant() - args = parser_wf_id.parse_args() - - # Validate name and comment length - if args.marked_name and len(args.marked_name) > 20: - raise ValueError("Marked name cannot exceed 20 characters") - if args.marked_comment and len(args.marked_comment) > 100: - raise ValueError("Marked comment cannot exceed 100 characters") - - # Prepare update data - update_data = {} - if args.get("marked_name") is not None: - update_data["marked_name"] = args["marked_name"] - if args.get("marked_comment") is not None: - update_data["marked_comment"] = args["marked_comment"] + payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {}) + update_data = payload.model_dump(exclude_unset=True) if not update_data: return {"message": "No valid fields to update"}, 400 @@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource): return workflow -parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args") - - @console_ns.route("/rag/pipelines//workflows/published/processing/parameters") class PublishedRagPipelineSecondStepApi(Resource): - @console_ns.expect(parser_parameters) @setup_required @login_required @account_initialization_required @@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource): """ Get second step parameters of rag pipeline """ - args = parser_parameters.parse_args() 
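The query models above rely on Pydantic's lax coercion of the all-string dict returned by request.args.to_dict(), which is what replaces inputs.int_range and inputs.boolean from flask_restx. For example, with values chosen purely for illustration:

    from pydantic import BaseModel, Field


    class ListQuery(BaseModel):  # mirrors WorkflowListQuery's constraints
        page: int = Field(default=1, ge=1, le=99999)
        limit: int = Field(default=10, ge=1, le=100)
        named_only: bool = False


    query = ListQuery.model_validate({"page": "2", "limit": "25", "named_only": "true"})
    print(query.page, query.limit, query.named_only)  # 2 25 True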
- node_id = args.get("node_id") - if not node_id: - raise ValueError("Node ID is required") + query = NodeIdQuery.model_validate(request.args.to_dict()) + node_id = query.node_id rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) return { @@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource): @console_ns.route("/rag/pipelines//workflows/published/pre-processing/parameters") class PublishedRagPipelineFirstStepApi(Resource): - @console_ns.expect(parser_parameters) @setup_required @login_required @account_initialization_required @@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource): """ Get first step parameters of rag pipeline """ - args = parser_parameters.parse_args() - node_id = args.get("node_id") - if not node_id: - raise ValueError("Node ID is required") + query = NodeIdQuery.model_validate(request.args.to_dict()) + node_id = query.node_id rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) return { @@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/pre-processing/parameters") class DraftRagPipelineFirstStepApi(Resource): - @console_ns.expect(parser_parameters) @setup_required @login_required @account_initialization_required @@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource): """ Get first step parameters of rag pipeline """ - args = parser_parameters.parse_args() - node_id = args.get("node_id") - if not node_id: - raise ValueError("Node ID is required") + query = NodeIdQuery.model_validate(request.args.to_dict()) + node_id = query.node_id rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) return { @@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/processing/parameters") class DraftRagPipelineSecondStepApi(Resource): - @console_ns.expect(parser_parameters) @setup_required @login_required @account_initialization_required @@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource): """ Get second step parameters of rag pipeline """ - args = parser_parameters.parse_args() - node_id = args.get("node_id") - if not node_id: - raise ValueError("Node ID is required") + query = NodeIdQuery.model_validate(request.args.to_dict()) + node_id = query.node_id rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) @@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource): } -parser_wf_run = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") -) - - @console_ns.route("/rag/pipelines//workflow-runs") class RagPipelineWorkflowRunListApi(Resource): - @console_ns.expect(parser_wf_run) @setup_required @login_required @account_initialization_required @@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource): """ Get workflow run list """ - args = parser_wf_run.parse_args() + query = WorkflowRunQuery.model_validate( + { + "last_id": request.args.get("last_id"), + "limit": request.args.get("limit", type=int, default=20), + } + ) + args = { + "last_id": str(query.last_id) 
if query.last_id else None, + "limit": query.limit, + } rag_pipeline_service = RagPipelineService() result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) @@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource): return result -parser_var = ( - reqparse.RequestParser() - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("datasource_info", type=dict, required=True, location="json") - .add_argument("start_node_id", type=str, required=True, location="json") - .add_argument("start_node_title", type=str, required=True, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/draft/datasource/variables-inspect") class RagPipelineDatasourceVariableApi(Resource): - @console_ns.expect(parser_var) + @console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource): Set datasource variables """ current_user, _ = current_account_with_tenant() - args = parser_var.parse_args() + args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump() rag_pipeline_service = RagPipelineService() workflow_node_execution = rag_pipeline_service.set_datasource_variables( diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index b2998a8d3e..335c8f6030 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,5 +1,10 @@ -from flask_restx import Resource, fields, reqparse +from typing import Literal +from flask import request +from flask_restx import Resource +from pydantic import BaseModel + +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.wraps import account_initialization_required, setup_required @@ -7,48 +12,35 @@ from libs.login import login_required from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService +class WebsiteCrawlPayload(BaseModel): + provider: Literal["firecrawl", "watercrawl", "jinareader"] + url: str + options: dict[str, object] + + +class WebsiteCrawlStatusQuery(BaseModel): + provider: Literal["firecrawl", "watercrawl", "jinareader"] + + +register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery) + + @console_ns.route("/website/crawl") class WebsiteCrawlApi(Resource): @console_ns.doc("crawl_website") @console_ns.doc(description="Crawl website content") - @console_ns.expect( - console_ns.model( - "WebsiteCrawlRequest", - { - "provider": fields.String( - required=True, - description="Crawl provider (firecrawl/watercrawl/jinareader)", - enum=["firecrawl", "watercrawl", "jinareader"], - ), - "url": fields.String(required=True, description="URL to crawl"), - "options": fields.Raw(required=True, description="Crawl options"), - }, - ) - ) + @console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__]) @console_ns.response(200, "Website crawl initiated successfully") @console_ns.response(400, "Invalid crawl parameters") @setup_required @login_required @account_initialization_required def post(self): - parser = ( - reqparse.RequestParser() - .add_argument( - "provider", - type=str, - choices=["firecrawl", "watercrawl", "jinareader"], - required=True, - nullable=True, - location="json", - ) - .add_argument("url", type=str, 
required=True, nullable=True, location="json") - .add_argument("options", type=dict, required=True, nullable=True, location="json") - ) - args = parser.parse_args() + payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {}) # Create typed request and validate try: - api_request = WebsiteCrawlApiRequest.from_args(args) + api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump()) except ValueError as e: raise WebsiteCrawlError(str(e)) @@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource): @console_ns.doc("get_crawl_status") @console_ns.doc(description="Get website crawl status") @console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"}) + @console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__]) @console_ns.response(200, "Crawl status retrieved successfully") @console_ns.response(404, "Crawl job not found") @console_ns.response(400, "Invalid provider") @@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource): @login_required @account_initialization_required def get(self, job_id: str): - parser = reqparse.RequestParser().add_argument( - "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" - ) - args = parser.parse_args() + args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict()) # Create typed request and validate try: - api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id) + api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id) except ValueError as e: raise WebsiteCrawlError(str(e)) diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 2a248cf20d..0311db1584 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,9 +1,11 @@ import logging from flask import request +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_model from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -31,6 +33,16 @@ from .. 
import console_ns logger = logging.getLogger(__name__) +class TextToAudioPayload(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = Field(default=None, description="Enable streaming response") + + +register_schema_model(console_ns, TextToAudioPayload) + + @console_ns.route( "/installed-apps//audio-to-text", endpoint="installed_app_audio", @@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource): endpoint="installed_app_text", ) class ChatTextApi(InstalledAppResource): + @console_ns.expect(console_ns.models[TextToAudioPayload.__name__]) def post(self, installed_app): - from flask_restx import reqparse - app_model = installed_app.app try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() + payload = TextToAudioPayload.model_validate(console_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) return response diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 52d6426e7f..78e9a87a3d 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,9 +1,12 @@ import logging +from typing import Any, Literal +from uuid import UUID -from flask_restx import reqparse +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now -from libs.helper import uuid_value from libs.login import current_user from models import Account from models.model import AppMode @@ -38,28 +40,42 @@ from .. 
import console_ns logger = logging.getLogger(__name__) +class CompletionMessagePayload(BaseModel): + inputs: dict[str, Any] + query: str = "" + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + retriever_from: str = Field(default="explore_app") + + +class ChatMessagePayload(BaseModel): + inputs: dict[str, Any] + query: str + files: list[dict[str, Any]] | None = None + conversation_id: UUID | None = None + parent_message_id: UUID | None = None + retriever_from: str = Field(default="explore_app") + + +register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) + + # define completion api for user @console_ns.route( "/installed-apps//completion-messages", endpoint="installed_app_completion", ) class CompletionApi(InstalledAppResource): + @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) def post(self, installed_app): app_model = installed_app.app if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - ) - args = parser.parse_args() + payload = CompletionMessagePayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False installed_app.last_used_at = naive_utc_now() @@ -123,22 +139,15 @@ class CompletionStopApi(InstalledAppResource): endpoint="installed_app_chat_completion", ) class ChatApi(InstalledAppResource): + @console_ns.expect(console_ns.models[ChatMessagePayload.__name__]) def post(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - ) - args = parser.parse_args() + payload = ChatMessagePayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) args["auto_generate_name"] = False diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 5a39363cc2..157d5a135b 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,14 +1,18 @@ -from flask_restx import marshal_with, reqparse -from flask_restx.inputs import int_range +from typing import Any +from uuid import UUID + +from flask import request +from flask_restx import marshal_with +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from controllers.common.schema import 
register_schema_models from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields -from libs.helper import uuid_value from libs.login import current_user from models import Account from models.model import AppMode @@ -19,29 +23,43 @@ from services.web_conversation_service import WebConversationService from .. import console_ns +class ConversationListQuery(BaseModel): + last_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + + +class ConversationRenamePayload(BaseModel): + name: str | None = None # optional, matching the old parser; ignored when auto_generate is true + auto_generate: bool = False + + +register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload) + + @console_ns.route( "/installed-apps//conversations", endpoint="installed_app_conversations", ) class ConversationListApi(InstalledAppResource): @marshal_with(conversation_infinite_scroll_pagination_fields) + @console_ns.expect(console_ns.models[ConversationListQuery.__name__]) def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args: dict[str, Any] = { + "last_id": request.args.get("last_id"), + "limit": request.args.get("limit", default=20, type=int), + "pinned": request.args.get("pinned"), + } + # Query strings carry booleans as text, so normalize "true"/"false" before validation. + pinned_value = raw_args["pinned"] + if isinstance(pinned_value, str): + raw_args["pinned"] = pinned_value == "true" + args = ConversationListQuery.model_validate(raw_args) try: if not isinstance(current_user, Account): @@ -51,10 +70,10 @@ class ConversationListApi(InstalledAppResource): session=session, app_model=app_model, user=current_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=str(args.last_id) if args.last_id else None, + limit=args.limit, invoke_from=InvokeFrom.EXPLORE, - pinned=pinned, + pinned=args.pinned, ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -88,6 +107,7 @@ class ConversationApi(InstalledAppResource): ) class ConversationRenameApi(InstalledAppResource): @marshal_with(simple_conversation_fields) + @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__]) def post(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) @@ -96,18 +116,13 @@ class ConversationRenameApi(InstalledAppResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(console_ns.payload or {}) try: if not isinstance(current_user, Account): raise
ValueError("current_user must be an Account instance") return ConversationService.rename( - app_model, conversation_id, current_user, args["name"], args["auto_generate"] + app_model, conversation_id, current_user, payload.name, payload.auto_generate ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index db854e09bb..229b7c8865 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,9 +1,13 @@ import logging +from typing import Literal +from uuid import UUID -from flask_restx import marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -22,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper -from libs.helper import uuid_value from libs.login import current_account_with_tenant from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -40,12 +43,31 @@ from .. import console_ns logger = logging.getLogger(__name__) +class MessageListQuery(BaseModel): + conversation_id: UUID + first_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class MessageFeedbackPayload(BaseModel): + rating: Literal["like", "dislike"] | None = None + content: str | None = None + + +class MoreLikeThisQuery(BaseModel): + response_mode: Literal["blocking", "streaming"] + + +register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery) + + @console_ns.route( "/installed-apps//messages", endpoint="installed_app_messages", ) class MessageListApi(InstalledAppResource): @marshal_with(message_infinite_scroll_pagination_fields) + @console_ns.expect(console_ns.models[MessageListQuery.__name__]) def get(self, installed_app): current_user, _ = current_account_with_tenant() app_model = installed_app.app @@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - - parser = ( - reqparse.RequestParser() - .add_argument("conversation_id", required=True, type=uuid_value, location="args") - .add_argument("first_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + args = MessageListQuery.model_validate(request.args.to_dict()) try: return MessageService.pagination_by_first_id( - app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, + current_user, + str(args.conversation_id), + str(args.first_id) if args.first_id else None, + args.limit, ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource): endpoint="installed_app_message_feedback", ) class MessageFeedbackApi(InstalledAppResource): + 
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__]) def post(self, installed_app, message_id): current_user, _ = current_account_with_tenant() app_model = installed_app.app message_id = str(message_id) - parser = ( - reqparse.RequestParser() - .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - .add_argument("content", type=str, location="json") - ) - args = parser.parse_args() + payload = MessageFeedbackPayload.model_validate(console_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, message_id=message_id, user=current_user, - rating=args.get("rating"), - content=args.get("content"), + rating=payload.rating, + content=payload.content, ) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource): endpoint="installed_app_more_like_this", ) class MessageMoreLikeThisApi(InstalledAppResource): + @console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__]) def get(self, installed_app, message_id): current_user, _ = current_account_with_tenant() app_model = installed_app.app @@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource): message_id = str(message_id) - parser = reqparse.RequestParser().add_argument( - "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" - ) - args = parser.parse_args() + args = MoreLikeThisQuery.model_validate(request.args.to_dict()) - streaming = args["response_mode"] == "streaming" + streaming = args.response_mode == "streaming" try: response = AppGenerateService.generate_more_like_this( diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 9775c951f7..6a9e274a0e 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,16 +1,33 @@ -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from uuid import UUID + +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField, uuid_value +from libs.helper import TimestampField from libs.login import current_account_with_tenant from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService + +class SavedMessageListQuery(BaseModel): + last_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SavedMessageCreatePayload(BaseModel): + message_id: UUID + + +register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload) + + feedback_fields = {"rating": fields.String} message_fields = { @@ -33,32 +50,33 @@ class SavedMessageListApi(InstalledAppResource): } @marshal_with(saved_message_infinite_scroll_pagination_fields) + @console_ns.expect(console_ns.models[SavedMessageListQuery.__name__]) def get(self, installed_app): current_user, _ = current_account_with_tenant() app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - 
.add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + args = SavedMessageListQuery.model_validate(request.args.to_dict()) + + return SavedMessageService.pagination_by_last_id( + app_model, + current_user, + str(args.last_id) if args.last_id else None, + args.limit, ) - args = parser.parse_args() - - return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) + @console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__]) def post(self, installed_app): current_user, _ = current_account_with_tenant() app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {}) try: - SavedMessageService.save(app_model, current_user, args["message_id"]) + SavedMessageService.save(app_model, current_user, str(payload.message_id)) except MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 125f603a5a..d679d0722d 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,8 +1,10 @@ import logging +from typing import Any -from flask_restx import reqparse +from pydantic import BaseModel from werkzeug.exceptions import InternalServerError +from controllers.common.schema import register_schema_model from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -32,8 +34,17 @@ from .. 
import console_ns logger = logging.getLogger(__name__) +class WorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + files: list[dict[str, Any]] | None = None + + +register_schema_model(console_ns, WorkflowRunPayload) + + @console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): + @console_ns.expect(console_ns.models[WorkflowRunPayload.__name__]) def post(self, installed_app: InstalledApp): """ Run workflow @@ -46,12 +57,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("files", type=list, required=False, location="json") - ) - args = parser.parse_args() + payload = WorkflowRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index 7e40d81706..885ab7b78d 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -1,29 +1,38 @@ -from flask_restx import Resource, reqparse +from typing import Any +from flask_restx import Resource +from pydantic import BaseModel, Field + +from controllers.common.schema import register_schema_model from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only from tasks.mail_inner_task import send_inner_email_task -_mail_parser = ( - reqparse.RequestParser() - .add_argument("to", type=str, action="append", required=True) - .add_argument("subject", type=str, required=True) - .add_argument("body", type=str, required=True) - .add_argument("substitutions", type=dict, required=False) -) + +class InnerMailPayload(BaseModel): + to: list[str] = Field(description="Recipient email addresses", min_length=1) + subject: str + body: str + substitutions: dict[str, Any] | None = None + + +register_schema_model(inner_api_ns, InnerMailPayload) class BaseMail(Resource): """Shared logic for sending an inner email.""" + @inner_api_ns.doc("send_inner_mail") + @inner_api_ns.doc(description="Send internal email") + @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__]) def post(self): - args = _mail_parser.parse_args() - send_inner_email_task.delay( # type: ignore - to=args["to"], - subject=args["subject"], - body=args["body"], - substitutions=args["substitutions"], + args = InnerMailPayload.model_validate(inner_api_ns.payload or {}) + send_inner_email_task.delay( + to=args.to, + subject=args.subject, + body=args.body, + substitutions=args.substitutions, # type: ignore ) return {"message": "success"}, 200 @@ -34,7 +43,7 @@ class EnterpriseMail(BaseMail): @inner_api_ns.doc("send_enterprise_mail") @inner_api_ns.doc(description="Send internal email for enterprise features") - @inner_api_ns.expect(_mail_parser) + @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__]) @inner_api_ns.doc( responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} ) @@ -56,7 +65,7 @@ class BillingMail(BaseMail): @inner_api_ns.doc("send_billing_mail") @inner_api_ns.doc(description="Send internal email for billing notifications") - 
@inner_api_ns.expect(_mail_parser) + @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__]) @inner_api_ns.doc( responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} ) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 2a57bb745b..edf3ac393c 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,10 +1,9 @@ from collections.abc import Callable from functools import wraps -from typing import ParamSpec, TypeVar, cast +from typing import ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in -from flask_restx import reqparse from pydantic import BaseModel from sqlalchemy.orm import Session @@ -17,6 +16,11 @@ P = ParamSpec("P") R = TypeVar("R") +class TenantUserPayload(BaseModel): + tenant_id: str + user_id: str | None = None # optional: get_user() accepts None and the view falls back to the default end-user session + + def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ Get current user @@ -67,58 +71,45 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Callable[P, R] | None = None): - def decorator(view_func: Callable[P, R]): - @wraps(view_func) - def decorated_view(*args: P.args, **kwargs: P.kwargs): - # fetch json body - parser = ( - reqparse.RequestParser() - .add_argument("tenant_id", type=str, required=True, location="json") - .add_argument("user_id", type=str, required=True, location="json") - ) +def get_user_tenant(view_func: Callable[P, R]): + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {}) - p = parser.parse_args() + user_id = payload.user_id + tenant_id = payload.tenant_id - user_id = cast(str, p.get("user_id")) - tenant_id = cast(str, p.get("tenant_id")) + if not tenant_id: + raise ValueError("tenant_id is required") - if not tenant_id: - raise ValueError("tenant_id is required") + if not user_id: + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - if not user_id: - user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - - try: - tenant_model = ( - db.session.query(Tenant) - .where( - Tenant.id == tenant_id, - ) - .first() + try: + tenant_model = ( + db.session.query(Tenant) + .where( + Tenant.id == tenant_id, ) - except Exception: - raise ValueError("tenant not found") + .first() + ) + except Exception: + raise ValueError("tenant not found") - if not tenant_model: - raise ValueError("tenant not found") + if not tenant_model: + raise ValueError("tenant not found") - kwargs["tenant_model"] = tenant_model + kwargs["tenant_model"] = tenant_model - user = get_user(tenant_id, user_id) - kwargs["user_model"] = user + user = get_user(tenant_id, user_id) + kwargs["user_model"] = user - current_app.login_manager._update_request_context_with_user(user) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore + current_app.login_manager._update_request_context_with_user(user) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore - return view_func(*args, **kwargs) + return view_func(*args, **kwargs) - return decorated_view - - if view is None: - return decorator - else: - return decorator(view) + return decorated_view def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): diff --git a/api/controllers/inner_api/workspace/workspace.py
b/api/controllers/inner_api/workspace/workspace.py index 8391a15919..a5746abafa 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,7 +1,9 @@ import json -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel +from controllers.common.schema import register_schema_models from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import enterprise_inner_api_only @@ -11,12 +13,25 @@ from models import Account from services.account_service import TenantService +class WorkspaceCreatePayload(BaseModel): + name: str + owner_email: str + + +class WorkspaceOwnerlessPayload(BaseModel): + name: str + + +register_schema_models(inner_api_ns, WorkspaceCreatePayload, WorkspaceOwnerlessPayload) + + @inner_api_ns.route("/enterprise/workspace") class EnterpriseWorkspace(Resource): @setup_required @enterprise_inner_api_only @inner_api_ns.doc("create_enterprise_workspace") @inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment") + @inner_api_ns.expect(inner_api_ns.models[WorkspaceCreatePayload.__name__]) @inner_api_ns.doc( responses={ 200: "Workspace created successfully", @@ -25,18 +40,13 @@ class EnterpriseWorkspace(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=True, location="json") - .add_argument("owner_email", type=str, required=True, location="json") - ) - args = parser.parse_args() + args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {}) - account = db.session.query(Account).filter_by(email=args["owner_email"]).first() + account = db.session.query(Account).filter_by(email=args.owner_email).first() if account is None: return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) + tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) TenantService.create_tenant_member(tenant, account, role="owner") tenant_was_created.send(tenant) @@ -62,6 +72,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): @enterprise_inner_api_only @inner_api_ns.doc("create_enterprise_workspace_ownerless") @inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment") + @inner_api_ns.expect(inner_api_ns.models[WorkspaceOwnerlessPayload.__name__]) @inner_api_ns.doc( responses={ 200: "Workspace created successfully", @@ -70,10 +81,9 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): } ) def post(self): - parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json") - args = parser.parse_args() + args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {}) - tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) + tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) tenant_was_created.send(tenant) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 8d8fe6b3a8..90137a10ba 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,10 +1,11 @@ -from typing import Union +from typing import Any, Union from flask import Response -from flask_restx import Resource, reqparse -from pydantic import ValidationError +from flask_restx import Resource +from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session +from controllers.common.schema import 
register_schema_model from controllers.console.app.mcp_server import AppMCPServerStatus from controllers.mcp import mcp_ns from core.app.app_config.entities import VariableEntity @@ -24,27 +25,19 @@ class MCPRequestError(Exception): super().__init__(message) -def int_or_str(value): - """Validate that a value is either an integer or string.""" - if isinstance(value, (int, str)): - return value - else: - return None +class MCPRequestPayload(BaseModel): + jsonrpc: str = Field(description="JSON-RPC version (should be '2.0')") + method: str = Field(description="The method to invoke") + params: dict[str, Any] | None = Field(default=None, description="Parameters for the method") + id: int | str | None = Field(default=None, description="Request ID for tracking responses") -# Define parser for both documentation and validation -mcp_request_parser = ( - reqparse.RequestParser() - .add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')") - .add_argument("method", type=str, required=True, location="json", help="The method to invoke") - .add_argument("params", type=dict, required=False, location="json", help="Parameters for the method") - .add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses") -) +register_schema_model(mcp_ns, MCPRequestPayload) @mcp_ns.route("/server//mcp") class MCPAppApi(Resource): - @mcp_ns.expect(mcp_request_parser) + @mcp_ns.expect(mcp_ns.models[MCPRequestPayload.__name__]) @mcp_ns.doc("handle_mcp_request") @mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server") @mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"}) @@ -70,9 +63,9 @@ class MCPAppApi(Resource): Raises: ValidationError: Invalid request format or parameters """ - args = mcp_request_parser.parse_args() - request_id: Union[int, str] | None = args.get("id") - mcp_request = self._parse_mcp_request(args) + args = MCPRequestPayload.model_validate(mcp_ns.payload or {}) + request_id: Union[int, str] | None = args.id + mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True)) with Session(db.engine, expire_on_commit=False) as session: # Get MCP server and app diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index f26718555a..63c373b50f 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,9 +1,11 @@ from typing import Literal from flask import request -from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx import Api, Namespace, Resource, fields from flask_restx.api import HTTPStatus +from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token @@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model from models.model import App from services.annotation_service import AppAnnotationService -# Define parsers for annotation API -annotation_create_parser = ( - reqparse.RequestParser() - .add_argument("question", required=True, type=str, location="json", help="Annotation question") - .add_argument("answer", required=True, type=str, location="json", help="Annotation answer") -) -annotation_reply_action_parser = ( - reqparse.RequestParser() - 
.add_argument( - "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching" - ) - .add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name") - .add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name") -) +class AnnotationCreatePayload(BaseModel): + question: str = Field(description="Annotation question") + answer: str = Field(description="Annotation answer") + + +class AnnotationReplyActionPayload(BaseModel): + score_threshold: float = Field(description="Score threshold for annotation matching") + embedding_provider_name: str = Field(description="Embedding provider name") + embedding_model_name: str = Field(description="Embedding model name") + + +register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload) @service_api_ns.route("/apps/annotation-reply/") class AnnotationReplyActionApi(Resource): - @service_api_ns.expect(annotation_reply_action_parser) + @service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__]) @service_api_ns.doc("annotation_reply_action") @service_api_ns.doc(description="Enable or disable annotation reply feature") @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"}) @@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource): @validate_app_token def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" - args = annotation_reply_action_parser.parse_args() + args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_model.id) elif action == "disable": @@ -126,7 +126,7 @@ class AnnotationListApi(Resource): "page": page, } - @service_api_ns.expect(annotation_create_parser) + @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("create_annotation") @service_api_ns.doc(description="Create a new annotation") @service_api_ns.doc( @@ -139,14 +139,14 @@ class AnnotationListApi(Resource): @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): """Create a new annotation.""" - args = annotation_create_parser.parse_args() + args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) return annotation, 201 @service_api_ns.route("/apps/annotations/") class AnnotationUpdateDeleteApi(Resource): - @service_api_ns.expect(annotation_create_parser) + @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("update_annotation") @service_api_ns.doc(description="Update an existing annotation") @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) @@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource): @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" - args = annotation_create_parser.parse_args() + args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) return annotation diff --git a/api/controllers/service_api/app/audio.py 
b/api/controllers/service_api/app/audio.py index c069a7ddfb..e383920460 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,10 +1,12 @@ import logging from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, @@ -84,19 +86,19 @@ class AudioApi(Resource): raise InternalServerError() -# Define parser for text-to-audio API -text_to_audio_parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json", help="Message ID") - .add_argument("voice", type=str, location="json", help="Voice to use for TTS") - .add_argument("text", type=str, location="json", help="Text to convert to audio") - .add_argument("streaming", type=bool, location="json", help="Enable streaming response") -) +class TextToAudioPayload(BaseModel): + message_id: str | None = Field(default=None, description="Message ID") + voice: str | None = Field(default=None, description="Voice to use for TTS") + text: str | None = Field(default=None, description="Text to convert to audio") + streaming: bool | None = Field(default=None, description="Enable streaming response") + + +register_schema_model(service_api_ns, TextToAudioPayload) @service_api_ns.route("/text-to-audio") class TextApi(Resource): - @service_api_ns.expect(text_to_audio_parser) + @service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__]) @service_api_ns.doc("text_to_audio") @service_api_ns.doc(description="Convert text to audio using text-to-speech") @service_api_ns.doc( @@ -114,11 +116,11 @@ class TextApi(Resource): Converts the provided text to audio using the specified voice. 
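Either text or a message_id referencing an existing message may be supplied in the payload.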
""" try: - args = text_to_audio_parser.parse_args() + payload = TextToAudioPayload.model_validate(service_api_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index c5dd919759..a037fe9254 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,10 +1,14 @@ import logging +from typing import Any, Literal +from uuid import UUID from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, @@ -26,7 +30,6 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService @@ -36,40 +39,31 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) -# Define parser for completion API -completion_parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion") - .add_argument("query", type=str, location="json", default="", help="The query string") - .add_argument("files", type=list, required=False, location="json", help="List of file attachments") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode") - .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source") -) +class CompletionRequestPayload(BaseModel): + inputs: dict[str, Any] + query: str = Field(default="") + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + retriever_from: str = Field(default="dev") -# Define parser for chat API -chat_parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") - .add_argument("query", type=str, required=True, location="json", help="The chat query") - .add_argument("files", type=list, required=False, location="json", help="List of file attachments") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode") - .add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") - .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source") - .add_argument( - "auto_generate_name", - type=bool, - required=False, - default=True, - location="json", - help="Auto generate conversation name", - ) - .add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced 
chat") -) + +class ChatRequestPayload(BaseModel): + inputs: dict[str, Any] + query: str + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + conversation_id: UUID | None = None + retriever_from: str = Field(default="dev") + auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") + workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") + + +register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload) @service_api_ns.route("/completion-messages") class CompletionApi(Resource): - @service_api_ns.expect(completion_parser) + @service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__]) @service_api_ns.doc("create_completion") @service_api_ns.doc(description="Create a completion for the given prompt") @service_api_ns.doc( @@ -91,12 +85,13 @@ class CompletionApi(Resource): if app_model.mode != AppMode.COMPLETION: raise AppUnavailableError() - args = completion_parser.parse_args() + payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {}) external_trace_id = get_external_trace_id(request) + args = payload.model_dump(exclude_none=True) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False @@ -162,7 +157,7 @@ class CompletionStopApi(Resource): @service_api_ns.route("/chat-messages") class ChatApi(Resource): - @service_api_ns.expect(chat_parser) + @service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__]) @service_api_ns.doc("create_chat_message") @service_api_ns.doc(description="Send a message in a chat conversation") @service_api_ns.doc( @@ -186,13 +181,14 @@ class ChatApi(Resource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - args = chat_parser.parse_args() + payload = ChatRequestPayload.model_validate(service_api_ns.payload or {}) external_trace_id = get_external_trace_id(request) + args = payload.model_dump(exclude_none=True) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" try: response = AppGenerateService.generate( diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c4e23dd2e7..724ad3448d 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,10 +1,15 @@ -from flask_restx import Resource, reqparse +from typing import Any, Literal +from uuid import UUID + +from flask import request +from flask_restx import Resource from flask_restx._http import HTTPStatus -from flask_restx.inputs import int_range +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound import services +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token @@ -19,74 +24,44 @@ from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_model, ) -from libs.helper import uuid_value from models.model import App, 
AppMode, EndUser from services.conversation_service import ConversationService -# Define parsers for conversation APIs -conversation_list_parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination") - .add_argument( - "limit", - type=int_range(1, 100), - required=False, - default=20, - location="args", - help="Number of conversations to return", - ) - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - help="Sort order for conversations", - ) -) -conversation_rename_parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json", help="New conversation name") - .add_argument( - "auto_generate", - type=bool, - required=False, - default=False, - location="json", - help="Auto-generate conversation name", +class ConversationListQuery(BaseModel): + last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( + default="-updated_at", description="Sort order for conversations" ) -) -conversation_variables_parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination") - .add_argument( - "limit", - type=int_range(1, 100), - required=False, - default=20, - location="args", - help="Number of variables to return", - ) -) -conversation_variable_update_parser = reqparse.RequestParser().add_argument( - # using lambda is for passing the already-typed value without modification - # if no lambda, it will be converted to string - # the string cannot be converted using json.loads - "value", - required=True, - location="json", - type=lambda x: x, - help="New value for the conversation variable", +class ConversationRenamePayload(BaseModel): + name: str | None = Field(default=None, description="New conversation name (ignored when auto_generate is true)") + auto_generate: bool = Field(default=False, description="Auto-generate conversation name") + + +class ConversationVariablesQuery(BaseModel): + last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") + + +class ConversationVariableUpdatePayload(BaseModel): + value: Any + + +register_schema_models( + service_api_ns, + ConversationListQuery, + ConversationRenamePayload, + ConversationVariablesQuery, + ConversationVariableUpdatePayload, ) @service_api_ns.route("/conversations") class ConversationApi(Resource): - @service_api_ns.expect(conversation_list_parser) + @service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__]) @service_api_ns.doc("list_conversations") @service_api_ns.doc(description="List all conversations for the current user") @service_api_ns.doc( @@ -107,7 +82,8 @@ class ConversationApi(Resource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - args = conversation_list_parser.parse_args() + query_args = ConversationListQuery.model_validate(request.args.to_dict()) + last_id = str(query_args.last_id) if query_args.last_id else None try: with Session(db.engine) as session: @@ -115,10 +91,10 @@ class ConversationApi(Resource): session=session, app_model=app_model, user=end_user, -
last_id=args["last_id"], - limit=args["limit"], + last_id=last_id, + limit=query_args.limit, invoke_from=InvokeFrom.SERVICE_API, - sort_by=args["sort_by"], + sort_by=query_args.sort_by, ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -155,7 +131,7 @@ class ConversationDetailApi(Resource): @service_api_ns.route("/conversations//name") class ConversationRenameApi(Resource): - @service_api_ns.expect(conversation_rename_parser) + @service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__]) @service_api_ns.doc("rename_conversation") @service_api_ns.doc(description="Rename a conversation or auto-generate a name") @service_api_ns.doc(params={"c_id": "Conversation ID"}) @@ -176,17 +152,17 @@ class ConversationRenameApi(Resource): conversation_id = str(c_id) - args = conversation_rename_parser.parse_args() + payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) + return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @service_api_ns.route("/conversations//variables") class ConversationVariablesApi(Resource): - @service_api_ns.expect(conversation_variables_parser) + @service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__]) @service_api_ns.doc("list_conversation_variables") @service_api_ns.doc(description="List all variables for a conversation") @service_api_ns.doc(params={"c_id": "Conversation ID"}) @@ -211,11 +187,12 @@ class ConversationVariablesApi(Resource): conversation_id = str(c_id) - args = conversation_variables_parser.parse_args() + query_args = ConversationVariablesQuery.model_validate(request.args.to_dict()) + last_id = str(query_args.last_id) if query_args.last_id else None try: return ConversationService.get_conversational_variable( - app_model, conversation_id, end_user, args["limit"], args["last_id"] + app_model, conversation_id, end_user, query_args.limit, last_id ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -223,7 +200,7 @@ class ConversationVariablesApi(Resource): @service_api_ns.route("/conversations//variables/") class ConversationVariableDetailApi(Resource): - @service_api_ns.expect(conversation_variable_update_parser) + @service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__]) @service_api_ns.doc("update_conversation_variable") @service_api_ns.doc(description="Update a conversation variable's value") @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"}) @@ -250,11 +227,11 @@ class ConversationVariableDetailApi(Resource): conversation_id = str(c_id) variable_id = str(variable_id) - args = conversation_variable_update_parser.parse_args() + payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {}) try: return ConversationService.update_conversation_variable( - app_model, conversation_id, variable_id, end_user, args["value"] + app_model, conversation_id, variable_id, end_user, payload.value ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py 
index b8e91f0657..60f422b88e 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -1,9 +1,11 @@ import logging from urllib.parse import quote -from flask import Response -from flask_restx import Resource, reqparse +from flask import Response, request +from flask_restx import Resource +from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( FileAccessDeniedError, @@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile logger = logging.getLogger(__name__) -# Define parser for file preview API -file_preview_parser = reqparse.RequestParser().add_argument( - "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" -) +class FilePreviewQuery(BaseModel): + as_attachment: bool = Field(default=False, description="Download as attachment") + + +register_schema_model(service_api_ns, FilePreviewQuery) @service_api_ns.route("/files//preview") @@ -32,7 +35,7 @@ class FilePreviewApi(Resource): Files can only be accessed if they belong to messages within the requesting app's context. """ - @service_api_ns.expect(file_preview_parser) + @service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__]) @service_api_ns.doc("preview_file") @service_api_ns.doc(description="Preview or download a file uploaded via Service API") @service_api_ns.doc(params={"file_id": "UUID of the file to preview"}) @@ -55,7 +58,7 @@ class FilePreviewApi(Resource): file_id = str(file_id) # Parse query parameters - args = file_preview_parser.parse_args() + args = FilePreviewQuery.model_validate(request.args.to_dict()) # Validate file ownership and get file objects _, upload_file = self._validate_file_ownership(file_id, app_model.id) @@ -67,7 +70,7 @@ class FilePreviewApi(Resource): raise FileNotFoundError(f"Failed to load file content: {str(e)}") # Build response with appropriate headers - response = self._build_file_response(generator, upload_file, args["as_attachment"]) + response = self._build_file_response(generator, upload_file, args.as_attachment) return response diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index b8e5ed28e4..d342f4e661 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,11 +1,15 @@ import json import logging +from typing import Literal +from uuid import UUID -from flask_restx import Api, Namespace, Resource, fields, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import Namespace, Resource, fields +from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token @@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import build_message_file_model from fields.message_fields import build_agent_thought_model, build_feedback_model from fields.raws import FilesContainedField -from libs.helper import TimestampField, uuid_value +from libs.helper import TimestampField from models.model 
import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -25,42 +29,26 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) -# Define parsers for message APIs -message_list_parser = ( - reqparse.RequestParser() - .add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID") - .add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") - .add_argument( - "limit", - type=int_range(1, 100), - required=False, - default=20, - location="args", - help="Number of messages to return", - ) -) - -message_feedback_parser = ( - reqparse.RequestParser() - .add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating") - .add_argument("content", type=str, location="json", help="Feedback content") -) - -feedback_list_parser = ( - reqparse.RequestParser() - .add_argument("page", type=int, default=1, location="args", help="Page number") - .add_argument( - "limit", - type=int_range(1, 101), - required=False, - default=20, - location="args", - help="Number of feedbacks per page", - ) -) +class MessageListQuery(BaseModel): + conversation_id: UUID + first_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") -def build_message_model(api_or_ns: Api | Namespace): +class MessageFeedbackPayload(BaseModel): + rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") + + +class FeedbackListQuery(BaseModel): + page: int = Field(default=1, ge=1, description="Page number") + limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page") + + +register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery) + + +def build_message_model(api_or_ns: Namespace): """Build the message model for the API or Namespace.""" # First build the nested models feedback_model = build_feedback_model(api_or_ns) @@ -90,7 +78,7 @@ def build_message_model(api_or_ns: Api | Namespace): return api_or_ns.model("Message", message_fields) -def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): +def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace): """Build the message infinite scroll pagination model for the API or Namespace.""" # Build the nested message model first message_model = build_message_model(api_or_ns) @@ -105,7 +93,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): @service_api_ns.route("/messages") class MessageListApi(Resource): - @service_api_ns.expect(message_list_parser) + @service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__]) @service_api_ns.doc("list_messages") @service_api_ns.doc(description="List messages in a conversation") @service_api_ns.doc( @@ -126,11 +114,13 @@ class MessageListApi(Resource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - args = message_list_parser.parse_args() + query_args = MessageListQuery.model_validate(request.args.to_dict()) + conversation_id = str(query_args.conversation_id) + first_id = str(query_args.first_id) if query_args.first_id else None try: return MessageService.pagination_by_first_id( - app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, end_user, conversation_id, 
first_id, query_args.limit ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -140,7 +130,7 @@ class MessageListApi(Resource): @service_api_ns.route("/messages//feedbacks") class MessageFeedbackApi(Resource): - @service_api_ns.expect(message_feedback_parser) + @service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__]) @service_api_ns.doc("create_message_feedback") @service_api_ns.doc(description="Submit feedback for a message") @service_api_ns.doc(params={"message_id": "Message ID"}) @@ -159,15 +149,15 @@ class MessageFeedbackApi(Resource): """ message_id = str(message_id) - args = message_feedback_parser.parse_args() + payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, message_id=message_id, user=end_user, - rating=args.get("rating"), - content=args.get("content"), + rating=payload.rating, + content=payload.content, ) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -177,7 +167,7 @@ class MessageFeedbackApi(Resource): @service_api_ns.route("/app/feedbacks") class AppGetFeedbacksApi(Resource): - @service_api_ns.expect(feedback_list_parser) + @service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__]) @service_api_ns.doc("get_app_feedbacks") @service_api_ns.doc(description="Get all feedbacks for the application") @service_api_ns.doc( @@ -192,8 +182,8 @@ class AppGetFeedbacksApi(Resource): Returns paginated list of all feedback submitted for messages in this app. """ - args = feedback_list_parser.parse_args() - feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"]) + query_args = FeedbackListQuery.model_validate(request.args.to_dict()) + feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit) return {"data": feedbacks} diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index af5eae463d..4964888fd6 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,12 +1,14 @@ import logging +from typing import Any, Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Api, Namespace, Resource, fields, reqparse -from flask_restx.inputs import int_range +from flask_restx import Api, Namespace, Resource, fields +from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( CompletionRequestError, @@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) -# Define parsers for workflow APIs -workflow_run_parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") -) -workflow_log_parser = ( - reqparse.RequestParser() - .add_argument("keyword", type=str, location="args") - .add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") - .add_argument("created_at__before", type=str, 
location="args") - .add_argument("created_at__after", type=str, location="args") - .add_argument( - "created_by_end_user_session_id", - type=str, - location="args", - required=False, - default=None, - ) - .add_argument( - "created_by_account", - type=str, - location="args", - required=False, - default=None, - ) - .add_argument("page", type=int_range(1, 99999), default=1, location="args") - .add_argument("limit", type=int_range(1, 100), default=20, location="args") -) +class WorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + + +class WorkflowLogQuery(BaseModel): + keyword: str | None = None + status: Literal["succeeded", "failed", "stopped"] | None = None + created_at__before: str | None = None + created_at__after: str | None = None + created_by_end_user_session_id: str | None = None + created_by_account: str | None = None + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=20, ge=1, le=100) + + +register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) workflow_run_fields = { "id": fields.String, @@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource): @service_api_ns.route("/workflows/run") class WorkflowRunApi(Resource): - @service_api_ns.expect(workflow_run_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__]) @service_api_ns.doc("run_workflow") @service_api_ns.doc(description="Execute a workflow") @service_api_ns.doc( @@ -154,11 +144,12 @@ class WorkflowRunApi(Resource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - args = workflow_run_parser.parse_args() + payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args.get("response_mode") == "streaming" + streaming = payload.response_mode == "streaming" try: response = AppGenerateService.generate( @@ -185,7 +176,7 @@ class WorkflowRunApi(Resource): @service_api_ns.route("/workflows//run") class WorkflowRunByIdApi(Resource): - @service_api_ns.expect(workflow_run_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__]) @service_api_ns.doc("run_workflow_by_id") @service_api_ns.doc(description="Execute a specific workflow by ID") @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"}) @@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - args = workflow_run_parser.parse_args() + payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) # Add workflow_id to args for AppGenerateService args["workflow_id"] = workflow_id @@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource): external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args.get("response_mode") == "streaming" + streaming = payload.response_mode == "streaming" try: response = AppGenerateService.generate( @@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource): @service_api_ns.route("/workflows/logs") class WorkflowAppLogApi(Resource): - @service_api_ns.expect(workflow_log_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__]) @service_api_ns.doc("get_workflow_logs") @service_api_ns.doc(description="Get workflow 
workflow_run_fields = { "id": fields.String, @@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource): @service_api_ns.route("/workflows/run") class WorkflowRunApi(Resource): - @service_api_ns.expect(workflow_run_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__]) @service_api_ns.doc("run_workflow") @service_api_ns.doc(description="Execute a workflow") @service_api_ns.doc( @@ -154,11 +144,12 @@ class WorkflowRunApi(Resource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - args = workflow_run_parser.parse_args() + payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args.get("response_mode") == "streaming" + streaming = payload.response_mode == "streaming" try: response = AppGenerateService.generate( @@ -185,7 +176,7 @@ class WorkflowRunApi(Resource): @service_api_ns.route("/workflows/<string:workflow_id>/run") class WorkflowRunByIdApi(Resource): - @service_api_ns.expect(workflow_run_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__]) @service_api_ns.doc("run_workflow_by_id") @service_api_ns.doc(description="Execute a specific workflow by ID") @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"}) @@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - args = workflow_run_parser.parse_args() + payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) # Add workflow_id to args for AppGenerateService args["workflow_id"] = workflow_id @@ -217,7 +209,7 @@ external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args.get("response_mode") == "streaming" + streaming = payload.response_mode == "streaming" try: response = AppGenerateService.generate( @@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource): @service_api_ns.route("/workflows/logs") class WorkflowAppLogApi(Resource): - @service_api_ns.expect(workflow_log_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__]) @service_api_ns.doc("get_workflow_logs") @service_api_ns.doc(description="Get workflow execution logs") @service_api_ns.doc( @@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource): Returns paginated workflow execution logs with filtering options. """ - args = workflow_log_parser.parse_args() + args = WorkflowLogQuery.model_validate(request.args.to_dict()) - args.status = WorkflowExecutionStatus(args.status) if args.status else None - if args.created_at__before: - args.created_at__before = isoparse(args.created_at__before) - - if args.created_at__after: - args.created_at__after = isoparse(args.created_at__after) + status = WorkflowExecutionStatus(args.status) if args.status else None + created_at_before = isoparse(args.created_at__before) if args.created_at__before else None + created_at_after = isoparse(args.created_at__after) if args.created_at__after else None # get paginate workflow app logs workflow_app_service = WorkflowAppService() @@ -311,9 +300,9 @@ session=session, app_model=app_model, keyword=args.keyword, - status=args.status, - created_at_before=args.created_at__before, - created_at_after=args.created_at__after, + status=status, + created_at_before=created_at_before, + created_at_after=created_at_after, page=args.page, limit=args.limit, created_by_end_user_session_id=args.created_by_end_user_session_id, diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 4cca3e6ce8..7692aeed23 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,10 +1,12 @@ from typing import Any, Literal, cast from flask import request -from flask_restx import marshal, reqparse +from flask_restx import marshal +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Forbidden, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError @@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import build_dataset_tag_fields from libs.login import current_user -from libs.validators import validate_description_length from models.account import Account -from models.dataset import Dataset, DatasetPermissionEnum +from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.tag_service import TagService -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 40: - raise ValueError("Name must be between 1 to 40 characters.") - return name +class DatasetCreatePayload(BaseModel): + name: str = Field(..., min_length=1, max_length=40) + description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400) + indexing_technique: Literal["high_quality", "economy"] | None = None + permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME + external_knowledge_api_id: str | None = None + provider: str = "vendor" + external_knowledge_id: str | None = None + retrieval_model: RetrievalModel | None = None + embedding_model: str | None = None + embedding_model_provider: str | None = None -#
Define parsers for dataset operations -dataset_create_parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument( - "description", - type=validate_description_length, - nullable=True, - required=False, - default="", - ) - .add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help="Invalid indexing technique.", - ) - .add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - required=False, - nullable=False, - ) - .add_argument( - "external_knowledge_api_id", - type=str, - nullable=True, - required=False, - default="_validate_name", - ) - .add_argument( - "provider", - type=str, - nullable=True, - required=False, - default="vendor", - ) - .add_argument( - "external_knowledge_id", - type=str, - nullable=True, - required=False, - ) - .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") -) +class DatasetUpdatePayload(BaseModel): + name: str | None = Field(default=None, min_length=1, max_length=40) + description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400) + indexing_technique: Literal["high_quality", "economy"] | None = None + permission: DatasetPermissionEnum | None = None + embedding_model: str | None = None + embedding_model_provider: str | None = None + retrieval_model: RetrievalModel | None = None + partial_member_list: list[str] | None = None + external_retrieval_model: dict[str, Any] | None = None + external_knowledge_id: str | None = None + external_knowledge_api_id: str | None = None -dataset_update_parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - help="type is required. 
Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument("description", location="json", store_missing=False, type=validate_description_length) - .add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help="Invalid indexing technique.", - ) - .add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - ) - .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") - .add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.") - .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") - .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") - .add_argument( - "external_retrieval_model", - type=dict, - required=False, - nullable=True, - location="json", - help="Invalid external retrieval model.", - ) - .add_argument( - "external_knowledge_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge id.", - ) - .add_argument( - "external_knowledge_api_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge api id.", - ) -) -tag_create_parser = reqparse.RequestParser().add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=lambda x: x - if x and 1 <= len(x) <= 50 - else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), -) +class TagNamePayload(BaseModel): + name: str = Field(..., min_length=1, max_length=50) -tag_update_parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=lambda x: x - if x and 1 <= len(x) <= 50 - else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), - ) - .add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) -) -tag_delete_parser = reqparse.RequestParser().add_argument( - "tag_id", nullable=False, required=True, help="Id of a tag.", type=str -) +class TagCreatePayload(TagNamePayload): + pass -tag_binding_parser = ( - reqparse.RequestParser() - .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.") - .add_argument( - "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." 
- ) -) -tag_unbinding_parser = ( - reqparse.RequestParser() - .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") +class TagUpdatePayload(TagNamePayload): + tag_id: str + + +class TagDeletePayload(BaseModel): + tag_id: str + + +class TagBindingPayload(BaseModel): + tag_ids: list[str] + target_id: str + + @field_validator("tag_ids") + @classmethod + def validate_tag_ids(cls, value: list[str]) -> list[str]: + if not value: + raise ValueError("Tag IDs is required.") + return value + + +class TagUnbindingPayload(BaseModel): + tag_id: str + target_id: str + + +register_schema_models( + service_api_ns, + DatasetCreatePayload, + DatasetUpdatePayload, + TagCreatePayload, + TagUpdatePayload, + TagDeletePayload, + TagBindingPayload, + TagUnbindingPayload, ) @@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 - @service_api_ns.expect(dataset_create_parser) + @service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__]) @service_api_ns.doc("create_dataset") @service_api_ns.doc(description="Create a new dataset") @service_api_ns.doc( @@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id): """Resource for creating datasets.""" - args = dataset_create_parser.parse_args() + payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {}) - embedding_model_provider = args.get("embedding_model_provider") - embedding_model = args.get("embedding_model") + embedding_model_provider = payload.embedding_model_provider + embedding_model = payload.embedding_model if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) - retrieval_model = args.get("retrieval_model") + retrieval_model = payload.retrieval_model if ( retrieval_model - and retrieval_model.get("reranking_model") - and retrieval_model.get("reranking_model").get("reranking_provider_name") + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name ): DatasetService.check_reranking_model_setting( tenant_id, - retrieval_model.get("reranking_model").get("reranking_provider_name"), - retrieval_model.get("reranking_model").get("reranking_model_name"), + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, ) try: assert isinstance(current_user, Account) dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, - name=args["name"], - description=args["description"], - indexing_technique=args["indexing_technique"], + name=payload.name, + description=payload.description, + indexing_technique=payload.indexing_technique, account=current_user, - permission=args["permission"], - provider=args["provider"], - external_knowledge_api_id=args["external_knowledge_api_id"], - external_knowledge_id=args["external_knowledge_id"], - embedding_model_provider=args["embedding_model_provider"], - embedding_model_name=args["embedding_model"], - retrieval_model=RetrievalModel.model_validate(args["retrieval_model"]) - if args["retrieval_model"] is not None - else None, + permission=str(payload.permission) if payload.permission else 
None, + provider=payload.provider, + external_knowledge_api_id=payload.external_knowledge_api_id, + external_knowledge_id=payload.external_knowledge_id, + embedding_model_provider=payload.embedding_model_provider, + embedding_model_name=payload.embedding_model, + retrieval_model=payload.retrieval_model, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource): return data, 200 - @service_api_ns.expect(dataset_update_parser) + @service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__]) @service_api_ns.doc("update_dataset") @service_api_ns.doc(description="Update an existing dataset") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") - args = dataset_update_parser.parse_args() - data = request.get_json() + payload_dict = service_api_ns.payload or {} + payload = DatasetUpdatePayload.model_validate(payload_dict) + update_data = payload.model_dump(exclude_unset=True) + if payload.permission is not None: + update_data["permission"] = str(payload.permission) + if payload.retrieval_model is not None: + update_data["retrieval_model"] = payload.retrieval_model.model_dump() # check embedding model setting - embedding_model_provider = data.get("embedding_model_provider") - embedding_model = data.get("embedding_model") - if data.get("indexing_technique") == "high_quality" or embedding_model_provider: + embedding_model_provider = payload.embedding_model_provider + embedding_model = payload.embedding_model + if payload.indexing_technique == "high_quality" or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model ) - retrieval_model = data.get("retrieval_model") + retrieval_model = payload.retrieval_model if ( retrieval_model - and retrieval_model.get("reranking_model") - and retrieval_model.get("reranking_model").get("reranking_provider_name") + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name ): DatasetService.check_reranking_model_setting( dataset.tenant_id, - retrieval_model.get("reranking_model").get("reranking_provider_name"), - retrieval_model.get("reranking_model").get("reranking_model_name"), + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( - current_user, dataset, data.get("permission"), data.get("partial_member_list") + current_user, + dataset, + str(payload.permission) if payload.permission else None, + payload.partial_member_list, ) - dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user) if dataset is None: raise NotFound("Dataset not found.") @@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource): assert isinstance(current_user, Account) tenant_id = current_user.current_tenant_id - if data.get("partial_member_list") and data.get("permission") == "partial_members": - DatasetPermissionService.update_partial_member_list( - tenant_id, dataset_id_str, data.get("partial_member_list") - ) + if payload.partial_member_list and 
payload.permission == DatasetPermissionEnum.PARTIAL_TEAM: + DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list) # clear partial member list when permission is only_me or all_team_members - elif ( - data.get("permission") == DatasetPermissionEnum.ONLY_ME - or data.get("permission") == DatasetPermissionEnum.ALL_TEAM - ): + elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}: DatasetPermissionService.clear_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) @@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource): return tags, 200 - @service_api_ns.expect(tag_create_parser) + @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) @service_api_ns.doc("create_dataset_tag") @service_api_ns.doc(description="Add a knowledge type tag") @service_api_ns.doc( @@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = tag_create_parser.parse_args() - args["type"] = "knowledge" - tag = TagService.save_tags(args) + payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) + tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 - @service_api_ns.expect(tag_update_parser) + @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__]) @service_api_ns.doc("update_dataset_tag") @service_api_ns.doc(description="Update a knowledge type tag") @service_api_ns.doc( @@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = tag_update_parser.parse_args() - args["type"] = "knowledge" - tag_id = args["tag_id"] - tag = TagService.update_tags(args, tag_id) + payload = TagUpdatePayload.model_validate(service_api_ns.payload or {}) + params = {"name": payload.name, "type": "knowledge"} + tag_id = payload.tag_id + tag = TagService.update_tags(params, tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource): return response, 200 - @service_api_ns.expect(tag_delete_parser) + @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__]) @service_api_ns.doc("delete_dataset_tag") @service_api_ns.doc(description="Delete a knowledge type tag") @service_api_ns.doc( @@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource): @edit_permission_required def delete(self, _, dataset_id): """Delete a knowledge type tag.""" - args = tag_delete_parser.parse_args() - TagService.delete_tag(args["tag_id"]) + payload = TagDeletePayload.model_validate(service_api_ns.payload or {}) + TagService.delete_tag(payload.tag_id) return 204 @service_api_ns.route("/datasets/tags/binding") class DatasetTagBindingApi(DatasetApiResource): - @service_api_ns.expect(tag_binding_parser) + @service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__]) @service_api_ns.doc("bind_dataset_tags") @service_api_ns.doc(description="Bind tags to a dataset") @service_api_ns.doc( @@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = tag_binding_parser.parse_args() - args["type"] = "knowledge" - 
TagService.save_tag_binding(args) + payload = TagBindingPayload.model_validate(service_api_ns.payload or {}) + TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"}) return 204 @service_api_ns.route("/datasets/tags/unbinding") class DatasetTagUnbindingApi(DatasetApiResource): - @service_api_ns.expect(tag_unbinding_parser) + @service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__]) @service_api_ns.doc("unbind_dataset_tag") @service_api_ns.doc(description="Unbind a tag from a dataset") @service_api_ns.doc( @@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = tag_unbinding_parser.parse_args() - args["type"] = "knowledge" - TagService.delete_tag_binding(args) + payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) + TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"}) return 204 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ed47e706b6..c800c0e4e1 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -3,8 +3,8 @@ from typing import Self from uuid import UUID from flask import request -from flask_restx import marshal, reqparse -from pydantic import BaseModel, model_validator +from flask_restx import marshal +from pydantic import BaseModel, Field, model_validator from sqlalchemy import desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService -# Define parsers for document operations -document_text_create_parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("text", type=str, required=True, nullable=False, location="json") - .add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - .add_argument("original_document_id", type=str, required=False, location="json") - .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") - .add_argument( - "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" - ) - .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") -) + +class DocumentTextCreatePayload(BaseModel): + name: str + text: str + process_rule: ProcessRule | None = None + original_document_id: str | None = None + doc_form: str = Field(default="text_model") + doc_language: str = Field(default="English") + indexing_technique: str | None = None + retrieval_model: RetrievalModel | None = None + embedding_model: str | None = None + embedding_model_provider: str | None = None + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel): return self -for m 
in [ProcessRule, RetrievalModel, DocumentTextUpdate]: +for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]: service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore @@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]: class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" - @service_api_ns.expect(document_text_create_parser) + @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__]) @service_api_ns.doc("create_document_by_text") @service_api_ns.doc(description="Create a new document by providing text content") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" - args = document_text_create_parser.parse_args() + payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource): - if not dataset.indexing_technique and not args["indexing_technique"]: + if not dataset.indexing_technique and not args.get("indexing_technique"): raise ValueError("indexing_technique is required.") - text = args.get("text") - name = args.get("name") - if text is None or name is None: - raise ValueError("Both 'text' and 'name' must be non-null values.") - - embedding_model_provider = args.get("embedding_model_provider") - embedding_model = args.get("embedding_model") + embedding_model_provider = payload.embedding_model_provider + embedding_model = payload.embedding_model if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) - retrieval_model = args.get("retrieval_model") + retrieval_model = payload.retrieval_model if ( retrieval_model - and retrieval_model.get("reranking_model") - and retrieval_model.get("reranking_model").get("reranking_provider_name") + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name ): DatasetService.check_reranking_model_setting( tenant_id, - retrieval_model.get("reranking_model").get("reranking_provider_name"), - retrieval_model.get("reranking_model").get("reranking_model_name"), + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, ) if not current_user: raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_text( - text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id + text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id ) data_source = { "type": "upload_file", @@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" - @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True) + @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__]) @service_api_ns.doc("update_document_by_text") @service_api_ns.doc(description="Update an existing document by providing text content") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @@ -189,22 +183,23 @@ class
DocumentUpdateByTextApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" - args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True) + payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() - + args = payload.model_dump(exclude_none=True) if not dataset: raise ValueError("Dataset does not exist.") - retrieval_model = args.get("retrieval_model") + retrieval_model = payload.retrieval_model if ( retrieval_model - and retrieval_model.get("reranking_model") - and retrieval_model.get("reranking_model").get("reranking_provider_name") + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name ): DatasetService.check_reranking_model_setting( tenant_id, - retrieval_model.get("reranking_model").get("reranking_provider_name"), - retrieval_model.get("reranking_model").get("reranking_model_name"), + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, ) # indexing_technique is already set in dataset since this is an update
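Because retrieval_model is now a typed RetrievalModel rather than a raw dict, the reranking checks above read nested attributes instead of chained .get() lookups, and malformed nesting is rejected at validation time rather than deep inside the handler. A small sketch with stand-in models (field names taken from the checks above; the real entities live in knowledge_entities):

from pydantic import BaseModel

# Stand-ins shaped like the service entities used above.
class RerankingModel(BaseModel):
    reranking_provider_name: str
    reranking_model_name: str

class RetrievalModel(BaseModel):
    reranking_model: RerankingModel | None = None

rm = RetrievalModel.model_validate(
    {"reranking_model": {"reranking_provider_name": "p", "reranking_model_name": "m"}}
)
# Attribute access replaces the chained dict lookups used previously.
if rm.reranking_model and rm.reranking_model.reranking_provider_name:
    print(rm.reranking_model.reranking_model_name)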
diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index f646f1f4fa..aab25c1af3 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,9 +1,11 @@ from typing import Literal from flask_login import current_user -from flask_restx import marshal, reqparse +from flask_restx import marshal +from pydantic import BaseModel from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_model, register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from fields.dataset_fields import dataset_metadata_fields @@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.metadata_service import MetadataService -# Define parsers for metadata APIs -metadata_create_parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type") - .add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name") -) -metadata_update_parser = reqparse.RequestParser().add_argument( - "name", type=str, required=True, nullable=False, location="json", help="New metadata name" -) +class MetadataUpdatePayload(BaseModel): + name: str -document_metadata_parser = reqparse.RequestParser().add_argument( - "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" -) + +register_schema_model(service_api_ns, MetadataUpdatePayload) +register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData) @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata") class DatasetMetadataCreateServiceApi(DatasetApiResource): - @service_api_ns.expect(metadata_create_parser) + @service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__]) @service_api_ns.doc("create_dataset_metadata") @service_api_ns.doc(description="Create metadata for a dataset") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create metadata for a dataset.""" - args = metadata_create_parser.parse_args() - metadata_args = MetadataArgs.model_validate(args) + metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {}) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>") class DatasetMetadataServiceApi(DatasetApiResource): - @service_api_ns.expect(metadata_update_parser) + @service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__]) @service_api_ns.doc("update_dataset_metadata") @service_api_ns.doc(description="Update metadata name") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) @@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, metadata_id): """Update metadata name.""" - args = metadata_update_parser.parse_args() + payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {}) dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"]) + metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name) return marshal(metadata, dataset_metadata_fields), 200 @service_api_ns.doc("delete_dataset_metadata") @@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata") class DocumentMetadataEditServiceApi(DatasetApiResource): - @service_api_ns.expect(document_metadata_parser) + @service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__]) @service_api_ns.doc("update_documents_metadata") @service_api_ns.doc(description="Update metadata for multiple documents") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - args = document_metadata_parser.parse_args() - metadata_args = MetadataOperationData.model_validate(args) + metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {}) MetadataService.update_documents_metadata(dataset, metadata_args)
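Create and bulk-edit now validate the request body directly into the existing service entities (MetadataArgs, MetadataOperationData) in a single step, so malformed payloads fail at the API boundary instead of inside the service. A sketch of the behavior, using a stand-in with the same two required fields as MetadataArgs:

from pydantic import BaseModel, ValidationError

class MetadataArgs(BaseModel):  # stand-in; the real entity lives in knowledge_entities
    type: str
    name: str

try:
    # Explicit nulls fail fast, matching the old parser's nullable=False behavior.
    MetadataArgs.model_validate({"type": None, "name": None})
except ValidationError as exc:
    print(exc.error_count())  # 2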
diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index c177e9180a..0a2017e2bd 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -4,12 +4,12 @@ from collections.abc import Generator from typing import Any from flask import request -from flask_restx import reqparse -from flask_restx.reqparse import ParseResult, RequestParser +from pydantic import BaseModel from werkzeug.exceptions import Forbidden import services from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError +from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import PipelineRunError from controllers.service_api.wraps import DatasetApiResource @@ -22,11 +22,25 @@ from models.dataset import Pipeline from models.engine import db from services.errors.file import FileTooLargeError, UnsupportedFileTypeError from services.file_service import FileService -from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity +from services.rag_pipeline.entity.pipeline_service_api_entities import ( + DatasourceNodeRunApiEntity, + PipelineRunApiEntity, +) from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService +class DatasourceNodeRunPayload(BaseModel): + inputs: dict[str, Any] + datasource_type: str + credential_id: str | None = None + is_published: bool + + +register_schema_model(service_api_ns, DatasourceNodeRunPayload) +register_schema_model(service_api_ns, PipelineRunApiEntity) + + @service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins") class DatasourcePluginsApi(DatasetApiResource): """Resource for datasource plugins.""" @@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__]) def post(self, tenant_id: str, dataset_id: str, node_id: str): """Resource for getting datasource plugins.""" - # Get query parameter to determine published or draft - parser: RequestParser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("credential_id", type=str, required=False, location="json") - .add_argument("is_published", type=bool, required=True, location="json") - ) - args: ParseResult = parser.parse_args() - - datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args) + payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {}) assert isinstance(current_user, Account) rag_pipeline_service: RagPipelineService = RagPipelineService() pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) + datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate( + { + **payload.model_dump(exclude_none=True), + "pipeline_id": str(pipeline.id), + "node_id": node_id, + } + ) return helper.compact_generate_response( PipelineGenerator.convert_to_event_stream( rag_pipeline_service.run_datasource_workflow_node(
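Route parameters never travel in the request body, so the handler above merges them into the validated payload before building the service entity; exclude_none=True keeps an unset credential_id from overriding the entity's own default. A simplified sketch of that merge with a stand-in entity (DatasourceNodeRunApiEntity's real shape may differ, and the values are illustrative):

from typing import Any
from pydantic import BaseModel

class DatasourceNodeRunPayload(BaseModel):
    inputs: dict[str, Any]
    datasource_type: str
    credential_id: str | None = None
    is_published: bool

class NodeRunEntity(BaseModel):  # stand-in for DatasourceNodeRunApiEntity
    inputs: dict[str, Any]
    datasource_type: str
    credential_id: str | None = None
    is_published: bool
    pipeline_id: str
    node_id: str

payload = DatasourceNodeRunPayload.model_validate(
    {"inputs": {}, "datasource_type": "online_document", "is_published": True}
)
entity = NodeRunEntity.model_validate(
    {**payload.model_dump(exclude_none=True), "pipeline_id": "pl-1", "node_id": "n-1"}
)
assert entity.credential_id is None  # entity default, not an explicit null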
@@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__]) def post(self, tenant_id: str, dataset_id: str): """Resource for running a rag pipeline.""" - parser: RequestParser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("datasource_info_list", type=list, required=True, location="json") - .add_argument("start_node_id", type=str, required=True, location="json") - .add_argument("is_published", type=bool, required=True, default=True, location="json") - .add_argument( - "response_mode", - type=str, - required=True, - choices=["streaming", "blocking"], - default="blocking", - location="json", - ) - ) - args: ParseResult = parser.parse_args() + payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {}) if not isinstance(current_user, Account): raise Forbidden() @@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource): response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate( pipeline=pipeline, user=current_user, - args=args, - invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER, - streaming=args.get("response_mode") == "streaming", + args=payload.model_dump(), + invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER, + streaming=payload.response_mode == "streaming", ) return helper.compact_generate_response(response) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 9ca500b044..b242fd2c3e 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,8 +1,12 @@ +from typing import Any + from flask import request -from flask_restx import marshal, reqparse +from flask_restx import marshal +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import ( @@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError -# Define parsers for segment operations -segment_create_parser = reqparse.RequestParser().add_argument( - "segments", type=list, required=False, nullable=True, location="json" -) -segment_list_parser = ( - reqparse.RequestParser() - .add_argument("status", type=str, action="append", default=[], location="args") - .add_argument("keyword", type=str, default=None, location="args") -) +class SegmentCreatePayload(BaseModel): + segments: list[dict[str, Any]] | None = None -segment_update_parser = reqparse.RequestParser().add_argument( - "segment", type=dict, required=False, nullable=True, location="json" -) -child_chunk_create_parser = reqparse.RequestParser().add_argument( - "content", type=str, required=True, nullable=False, location="json" -) +class SegmentListQuery(BaseModel): + status: list[str] = Field(default_factory=list) + keyword: str | None = None -child_chunk_list_parser = ( - reqparse.RequestParser() - .add_argument("limit", type=int, default=20, location="args") - .add_argument("keyword", type=str, default=None, location="args") - .add_argument("page", type=int, default=1, location="args") -) -child_chunk_update_parser = reqparse.RequestParser().add_argument( - "content", type=str, required=True, nullable=False, location="json" +class SegmentUpdatePayload(BaseModel): + segment: SegmentUpdateArgs + + +class ChildChunkCreatePayload(BaseModel): + content: str + + +class ChildChunkListQuery(BaseModel): + limit: int = Field(default=20, ge=1) + keyword: str | None = None + page: int = Field(default=1, ge=1) + + +class ChildChunkUpdatePayload(BaseModel): + content: str + + +register_schema_models( + service_api_ns, + SegmentCreatePayload, + SegmentListQuery,
SegmentUpdatePayload, + ChildChunkCreatePayload, + ChildChunkListQuery, + ChildChunkUpdatePayload, ) @@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument( class SegmentApi(DatasetApiResource): """Resource for segments.""" - @service_api_ns.expect(segment_create_parser) + @service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__]) @service_api_ns.doc("create_segments") @service_api_ns.doc(description="Create segments in a document") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args - args = segment_create_parser.parse_args() - if args["segments"] is not None: + payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {}) + if payload.segments is not None: segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST - if segments_limit > 0 and len(args["segments"]) > segments_limit: + if segments_limit > 0 and len(payload.segments) > segments_limit: raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.") - for args_item in args["segments"]: + for args_item in payload.segments: SegmentService.segment_create_args_validate(args_item, document) - segments = SegmentService.multi_create_segment(args["segments"], document, dataset) + segments = SegmentService.multi_create_segment(payload.segments, document, dataset) return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 else: return {"error": "Segments is required"}, 400 - @service_api_ns.expect(segment_list_parser) + @service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__]) @service_api_ns.doc("list_segments") @service_api_ns.doc(description="List segments in a document") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - args = segment_list_parser.parse_args() + args = SegmentListQuery.model_validate( + { + "status": request.args.getlist("status"), + "keyword": request.args.get("keyword"), + } + ) segments, total = SegmentService.get_segments( document_id=document_id, tenant_id=current_tenant_id, - status_list=args["status"], - keyword=args["keyword"], + status_list=args.status, + keyword=args.keyword, page=page, limit=limit, ) @@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource): SegmentService.delete_segment(segment, document, dataset) return 204 - @service_api_ns.expect(segment_update_parser) + @service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__]) @service_api_ns.doc("update_segment") @service_api_ns.doc(description="Update a specific segment") @service_api_ns.doc( @@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") - # validate args - args = segment_update_parser.parse_args() + payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {}) - updated_segment = SegmentService.update_segment( - SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset - ) + updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 @service_api_ns.doc("get_segment") @@ -308,7 +322,7 @@ class 
DatasetSegmentApi(DatasetApiResource): class ChildChunkApi(DatasetApiResource): """Resource for child chunks.""" - @service_api_ns.expect(child_chunk_create_parser) + @service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__]) @service_api_ns.doc("create_child_chunk") @service_api_ns.doc(description="Create a new child chunk for a segment") @service_api_ns.doc( @@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource): raise ProviderNotInitializeError(ex.description) # validate args - args = child_chunk_create_parser.parse_args() + payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {}) try: - child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset) + child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - @service_api_ns.expect(child_chunk_list_parser) + @service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__]) @service_api_ns.doc("list_child_chunks") @service_api_ns.doc(description="List child chunks for a segment") @service_api_ns.doc( @@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") - args = child_chunk_list_parser.parse_args() + args = ChildChunkListQuery.model_validate( + { + "limit": request.args.get("limit", default=20, type=int), + "keyword": request.args.get("keyword"), + "page": request.args.get("page", default=1, type=int), + } + ) - page = args["page"] - limit = min(args["limit"], 100) - keyword = args["keyword"] + page = args.page + limit = min(args.limit, 100) + keyword = args.keyword child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) @@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource): return 204 - @service_api_ns.expect(child_chunk_update_parser) + @service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__]) @service_api_ns.doc("update_child_chunk") @service_api_ns.doc(description="Update a specific child chunk") @service_api_ns.doc( @@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Child chunk not found.") # validate args - args = child_chunk_update_parser.parse_args() + payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {}) try: - child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset) + child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index cdbd2355ca..dfb49cf2bd 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -101,8 +101,8 @@ class HitTestingService: dataset: Dataset, query: str, account: Account, - external_retrieval_model: dict, - metadata_filtering_conditions: dict, + external_retrieval_model: dict | None = None, + metadata_filtering_conditions: dict | None = None, ): if dataset.provider != "external": return { diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index eaa489d56b..c80758c857 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py 
+++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -125,7 +125,7 @@ class TestPartnerTenants: resource = PartnerTenants() # Act & Assert - # reqparse will raise BadRequest for missing required field + # Validation should raise BadRequest for missing required field with pytest.raises(BadRequest): resource.put(partner_key_encoded) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py index 5c484403a6..acff191c79 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -256,24 +256,18 @@ class TestFilePreviewApi: mock_app, # App query for tenant validation ] - with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse: - # Mock request parsing - mock_parser = Mock() - mock_parser.parse_args.return_value = {"as_attachment": False} - mock_reqparse.RequestParser.return_value = mock_parser + # Test the core logic directly without Flask decorators + # Validate file ownership + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file - # Test the core logic directly without Flask decorators - # Validate file ownership - result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) - assert result_message_file == mock_message_file - assert result_upload_file == mock_upload_file + # Test file response building + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + assert response is not None - # Test file response building - response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) - assert response is not None - - # Verify storage was called correctly - mock_storage.load.assert_not_called() # Since we're testing components separately + # Verify storage was called correctly + mock_storage.load.assert_not_called() # Since we're testing components separately @patch("controllers.service_api.app.file_preview.storage") def test_storage_error_handling( diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index bbfa9da15e..fc3a2fc416 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -2,8 +2,6 @@ from pathlib import Path from unittest.mock import Mock, create_autospec, patch import pytest -from flask_restx import reqparse -from werkzeug.exceptions import BadRequest from models.account import Account from services.entities.knowledge_entities.knowledge_entities import MetadataArgs @@ -77,60 +75,39 @@ class TestMetadataBugCompleteValidation: assert type_column.nullable is False, "type column should be nullable=False" assert name_column.nullable is False, "name column should be nullable=False" - def test_4_fixed_api_layer_rejects_null(self, app): - """Test Layer 4: Fixed API configuration properly rejects null values.""" - # Test Console API create endpoint (fixed) - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - ) + def test_4_fixed_api_layer_rejects_null(self): + """Test Layer 4: Fixed API 
configuration properly rejects null values using Pydantic.""" + with pytest.raises((ValueError, TypeError)): + MetadataArgs.model_validate({"type": None, "name": None}) - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - with pytest.raises(BadRequest): - parser.parse_args() + with pytest.raises((ValueError, TypeError)): + MetadataArgs.model_validate({"type": "string", "name": None}) - # Test with just name being null - with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"): - with pytest.raises(BadRequest): - parser.parse_args() + with pytest.raises((ValueError, TypeError)): + MetadataArgs.model_validate({"type": None, "name": "test"}) - # Test with just type being null - with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"): - with pytest.raises(BadRequest): - parser.parse_args() - - def test_5_fixed_api_accepts_valid_values(self, app): + def test_5_fixed_api_accepts_valid_values(self): """Test that fixed API still accepts valid non-null values.""" - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - ) + args = MetadataArgs.model_validate({"type": "string", "name": "valid_name"}) + assert args.type == "string" + assert args.name == "valid_name" - with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"): - args = parser.parse_args() - assert args["type"] == "string" - assert args["name"] == "valid_name" + def test_6_simulated_buggy_behavior(self): + """Test simulating the original buggy behavior by bypassing Pydantic validation.""" + mock_metadata_args = Mock() + mock_metadata_args.name = None + mock_metadata_args.type = None - def test_6_simulated_buggy_behavior(self, app): - """Test simulating the original buggy behavior with nullable=True.""" - # Simulate the old buggy configuration - buggy_parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=True, location="json") - .add_argument("name", type=str, required=True, nullable=True, location="json") - ) + mock_user = create_autospec(Account, instance=True) + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - # This would pass in the buggy version - args = buggy_parser.parse_args() - assert args["type"] is None - assert args["name"] is None - - # But would crash when trying to create MetadataArgs - with pytest.raises((ValueError, TypeError)): - MetadataArgs.model_validate(args) + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args) def test_7_end_to_end_validation_layers(self): """Test all validation layers work together correctly.""" diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index c8a1a70422..f43f394489 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -1,7 +1,6 @@ from unittest.mock import Mock, create_autospec, patch import 
pytest -from flask_restx import reqparse from models.account import Account from services.entities.knowledge_entities.knowledge_entities import MetadataArgs @@ -51,76 +50,16 @@ class TestMetadataNullableBug: with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.update_metadata_name("dataset-123", "metadata-456", None) - def test_api_parser_accepts_null_values(self, app): - """Test that API parser configuration incorrectly accepts null values.""" - # Simulate the current API parser configuration - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=True, location="json") - .add_argument("name", type=str, required=True, nullable=True, location="json") - ) + def test_api_layer_now_uses_pydantic_validation(self): + """Verify that API layer relies on Pydantic validation instead of reqparse.""" + invalid_payload = {"type": None, "name": None} + with pytest.raises((ValueError, TypeError)): + MetadataArgs.model_validate(invalid_payload) - # Simulate request data with null values - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - # This should parse successfully due to nullable=True - args = parser.parse_args() - - # Verify that null values are accepted - assert args["type"] is None - assert args["name"] is None - - # This demonstrates the bug: API accepts None but business logic will crash - - def test_integration_bug_scenario(self, app): - """Test the complete bug scenario from API to service layer.""" - # Step 1: API parser accepts null values (current buggy behavior) - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=True, location="json") - .add_argument("name", type=str, required=True, nullable=True, location="json") - ) - - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - args = parser.parse_args() - - # Step 2: Try to create MetadataArgs with None values - # This should fail at Pydantic validation level - with pytest.raises((ValueError, TypeError)): - metadata_args = MetadataArgs.model_validate(args) - - # Step 3: If we bypass Pydantic (simulating the bug scenario) - # Move this outside the request context to avoid Flask-Login issues - mock_metadata_args = Mock() - mock_metadata_args.name = None # From args["name"] - mock_metadata_args.type = None # From args["type"] - - mock_user = create_autospec(Account, instance=True) - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" - - with patch( - "services.metadata_service.current_account_with_tenant", - return_value=(mock_user, mock_user.current_tenant_id), - ): - # Step 4: Service layer crashes on len(None) - with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args) - - def test_correct_nullable_false_configuration_works(self, app): - """Test that the correct nullable=False configuration works as expected.""" - # This tests the FIXED configuration - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - ) - - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - # This should fail with BadRequest due to nullable=False - from werkzeug.exceptions import BadRequest - - with pytest.raises(BadRequest): - 
parser.parse_args() + valid_payload = {"type": "string", "name": "valid"} + args = MetadataArgs.model_validate(valid_payload) + assert args.type == "string" + assert args.name == "valid" if __name__ == "__main__":
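One detail makes the reworked assertions above safe: Pydantic v2's ValidationError subclasses ValueError, so existing pytest.raises((ValueError, TypeError)) checks keep passing after the migration, and new tests can assert the narrower type directly. A short sketch using a stand-in shaped like the entity under test:

import pytest
from pydantic import BaseModel, ValidationError

class MetadataArgs(BaseModel):  # stand-in mirroring the entity under test
    type: str
    name: str

def test_null_fields_rejected():
    # ValidationError is a ValueError, so either exception type satisfies the raise check.
    with pytest.raises(ValidationError):
        MetadataArgs.model_validate({"type": None, "name": "ok"})

assert issubclass(ValidationError, ValueError)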