refactor: port reqparse to BaseModel (#28993)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
commit 05fe92a541
parent 2f96374837
Author: Asuka Minato
Date: 2025-12-08 15:31:19 +09:00
Committed via GitHub
44 changed files with 1531 additions and 1894 deletions

File: (path not shown)

@@ -0,0 +1,26 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
from flask_restx import Namespace
from pydantic import BaseModel
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a single BaseModel with a namespace for Swagger documentation."""
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
"""Register multiple BaseModels with a namespace."""
for model in models:
register_schema_model(namespace, model)
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"register_schema_model",
"register_schema_models",
]
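A minimal usage sketch of these helpers, assuming a Flask-RESTX namespace; the GreetPayload model and /greet route are illustrative, not part of this commit:

    from flask_restx import Namespace, Resource
    from pydantic import BaseModel, Field

    from controllers.common.schema import register_schema_model

    ns = Namespace("example")


    class GreetPayload(BaseModel):
        name: str = Field(..., min_length=1)


    # Register the Pydantic schema so Swagger can reference it by class name.
    register_schema_model(ns, GreetPayload)


    @ns.route("/greet")
    class GreetApi(Resource):
        @ns.expect(ns.models[GreetPayload.__name__])
        def post(self):
            # Validate the JSON body with Pydantic instead of reqparse.
            payload = GreetPayload.model_validate(ns.payload or {})
            return {"message": f"Hello, {payload.name}"}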

File: (path not shown)

@@ -31,7 +31,6 @@ from fields.app_fields import (
 from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
 from libs.helper import AppIconUrlField, TimestampField
 from libs.login import current_account_with_tenant, login_required
-from libs.validators import validate_description_length
 from models import App, Workflow
 from services.app_dsl_service import AppDslService, ImportMode
 from services.app_service import AppService
@@ -76,51 +75,30 @@ class AppListQuery(BaseModel):
 class CreateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")
-    description: str | None = Field(default=None, description="App description (max 400 chars)")
+    description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
     mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
     icon_type: str | None = Field(default=None, description="Icon type")
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")
-
-    @field_validator("description")
-    @classmethod
-    def validate_description(cls, value: str | None) -> str | None:
-        if value is None:
-            return value
-        return validate_description_length(value)
 
 
 class UpdateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")
-    description: str | None = Field(default=None, description="App description (max 400 chars)")
+    description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
     icon_type: str | None = Field(default=None, description="Icon type")
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")
     use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
     max_active_requests: int | None = Field(default=None, description="Maximum active requests")
-
-    @field_validator("description")
-    @classmethod
-    def validate_description(cls, value: str | None) -> str | None:
-        if value is None:
-            return value
-        return validate_description_length(value)
 
 
 class CopyAppPayload(BaseModel):
     name: str | None = Field(default=None, description="Name for the copied app")
-    description: str | None = Field(default=None, description="Description for the copied app")
+    description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
     icon_type: str | None = Field(default=None, description="Icon type")
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")
-
-    @field_validator("description")
-    @classmethod
-    def validate_description(cls, value: str | None) -> str | None:
-        if value is None:
-            return value
-        return validate_description_length(value)
 
 
 class AppExportQuery(BaseModel):
     include_secret: bool = Field(default=False, description="Include secrets in export")

File: (path not shown)

@@ -1,15 +1,15 @@
 import json
 from collections.abc import Generator
-from typing import cast
+from typing import Any, cast
 
 from flask import request
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 
-from controllers.console import console_ns
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.common.schema import register_schema_model
 from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.indexing_runner import IndexingRunner
@@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService
 from services.datasource_provider_service import DatasourceProviderService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 
+from .. import console_ns
+from ..wraps import account_initialization_required, setup_required
+
+
+class NotionEstimatePayload(BaseModel):
+    notion_info_list: list[dict[str, Any]]
+    process_rule: dict[str, Any]
+    doc_form: str = Field(default="text_model")
+    doc_language: str = Field(default="English")
+
+
+register_schema_model(console_ns, NotionEstimatePayload)
+
 
 @console_ns.route(
     "/data-source/integrates",
@@ -243,20 +256,15 @@ class DataSourceNotionApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
     def post(self):
         _, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
-            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump()
 
         # validate args
         DocumentService.estimate_args_validate(args)
-        notion_info_list = args["notion_info_list"]
+        notion_info_list = payload.notion_info_list
         extract_settings = []
         for notion_info in notion_info_list:
             workspace_id = notion_info["workspace_id"]

File: (path not shown)

@@ -1,12 +1,14 @@
 from typing import Any, cast
 
 from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
 from configs import dify_config
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.apikey import (
     api_key_item_model,
@@ -48,7 +50,6 @@ from fields.dataset_fields import (
 )
 from fields.document_fields import document_status_fields
 from libs.login import current_account_with_tenant, login_required
-from libs.validators import validate_description_length
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermissionEnum
 from models.provider_ids import ModelProviderID
@@ -107,10 +108,74 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
 related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
 
 
-def _validate_name(name: str) -> str:
-    if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError("Name must be between 1 to 40 characters.")
-    return name
+def _validate_indexing_technique(value: str | None) -> str | None:
+    if value is None:
+        return value
+    if value not in Dataset.INDEXING_TECHNIQUE_LIST:
+        raise ValueError("Invalid indexing technique.")
+    return value
+
+
+class DatasetCreatePayload(BaseModel):
+    name: str = Field(..., min_length=1, max_length=40)
+    description: str = Field("", max_length=400)
+    indexing_technique: str | None = None
+    permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
+    provider: str = "vendor"
+    external_knowledge_api_id: str | None = None
+    external_knowledge_id: str | None = None
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str | None) -> str | None:
+        return _validate_indexing_technique(value)
+
+    @field_validator("provider")
+    @classmethod
+    def validate_provider(cls, value: str) -> str:
+        if value not in Dataset.PROVIDER_LIST:
+            raise ValueError("Invalid provider.")
+        return value
+
+
+class DatasetUpdatePayload(BaseModel):
+    name: str | None = Field(None, min_length=1, max_length=40)
+    description: str | None = Field(None, max_length=400)
+    permission: DatasetPermissionEnum | None = None
+    indexing_technique: str | None = None
+    embedding_model: str | None = None
+    embedding_model_provider: str | None = None
+    retrieval_model: dict[str, Any] | None = None
+    partial_member_list: list[str] | None = None
+    external_retrieval_model: dict[str, Any] | None = None
+    external_knowledge_id: str | None = None
+    external_knowledge_api_id: str | None = None
+    icon_info: dict[str, Any] | None = None
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str | None) -> str | None:
+        return _validate_indexing_technique(value)
+
+
+class IndexingEstimatePayload(BaseModel):
+    info_list: dict[str, Any]
+    process_rule: dict[str, Any]
+    indexing_technique: str
+    doc_form: str = "text_model"
+    dataset_id: str | None = None
+    doc_language: str = "English"
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str) -> str:
+        result = _validate_indexing_technique(value)
+        if result is None:
+            raise ValueError("indexing_technique is required.")
+        return result
+
+
+register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
+
 
 def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@@ -255,20 +320,7 @@ class DatasetListApi(Resource):
@console_ns.doc("create_dataset") @console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset") @console_ns.doc(description="Create a new dataset")
@console_ns.expect( @console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
console_ns.model(
"CreateDatasetRequest",
{
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
"description": fields.String(description="Dataset description (max 400 characters)"),
"indexing_technique": fields.String(description="Indexing technique"),
"permission": fields.String(description="Dataset permission"),
"provider": fields.String(description="Provider"),
"external_knowledge_api_id": fields.String(description="External knowledge API ID"),
"external_knowledge_id": fields.String(description="External knowledge ID"),
},
)
)
@console_ns.response(201, "Dataset created successfully") @console_ns.response(201, "Dataset created successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@@ -276,52 +328,7 @@ class DatasetListApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                required=True,
-                help="type is required. Name must be between 1 to 40 characters.",
-                type=_validate_name,
-            )
-            .add_argument(
-                "description",
-                type=validate_description_length,
-                nullable=True,
-                required=False,
-                default="",
-            )
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                location="json",
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                help="Invalid indexing technique.",
-            )
-            .add_argument(
-                "external_knowledge_api_id",
-                type=str,
-                nullable=True,
-                required=False,
-            )
-            .add_argument(
-                "provider",
-                type=str,
-                nullable=True,
-                choices=Dataset.PROVIDER_LIST,
-                required=False,
-                default="vendor",
-            )
-            .add_argument(
-                "external_knowledge_id",
-                type=str,
-                nullable=True,
-                required=False,
-            )
-        )
-        args = parser.parse_args()
+        payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
         current_user, current_tenant_id = current_account_with_tenant()
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -331,14 +338,14 @@ class DatasetListApi(Resource):
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_tenant_id,
-                name=args["name"],
-                description=args["description"],
-                indexing_technique=args["indexing_technique"],
+                name=payload.name,
+                description=payload.description,
+                indexing_technique=payload.indexing_technique,
                 account=current_user,
-                permission=DatasetPermissionEnum.ONLY_ME,
-                provider=args["provider"],
-                external_knowledge_api_id=args["external_knowledge_api_id"],
-                external_knowledge_id=args["external_knowledge_id"],
+                permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
+                provider=payload.provider,
+                external_knowledge_api_id=payload.external_knowledge_api_id,
+                external_knowledge_id=payload.external_knowledge_id,
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -399,18 +406,7 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset") @console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details") @console_ns.doc(description="Update dataset details")
@console_ns.expect( @console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
console_ns.model(
"UpdateDatasetRequest",
{
"name": fields.String(description="Dataset name"),
"description": fields.String(description="Dataset description"),
"permission": fields.String(description="Dataset permission"),
"indexing_technique": fields.String(description="Indexing technique"),
"external_retrieval_model": fields.Raw(description="External retrieval model settings"),
},
)
)
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model) @console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@@ -424,93 +420,26 @@ class DatasetApi(Resource):
         if dataset is None:
             raise NotFound("Dataset not found.")
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                help="type is required. Name must be between 1 to 40 characters.",
-                type=_validate_name,
-            )
-            .add_argument("description", location="json", store_missing=False, type=validate_description_length)
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                location="json",
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                help="Invalid indexing technique.",
-            )
-            .add_argument(
-                "permission",
-                type=str,
-                location="json",
-                choices=(
-                    DatasetPermissionEnum.ONLY_ME,
-                    DatasetPermissionEnum.ALL_TEAM,
-                    DatasetPermissionEnum.PARTIAL_TEAM,
-                ),
-                help="Invalid permission.",
-            )
-            .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
-            .add_argument(
-                "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
-            )
-            .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
-            .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
-            .add_argument(
-                "external_retrieval_model",
-                type=dict,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external retrieval model.",
-            )
-            .add_argument(
-                "external_knowledge_id",
-                type=str,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external knowledge id.",
-            )
-            .add_argument(
-                "external_knowledge_api_id",
-                type=str,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external knowledge api id.",
-            )
-            .add_argument(
-                "icon_info",
-                type=dict,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid icon info.",
-            )
-        )
-        args = parser.parse_args()
-        data = request.get_json()
+        payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
+        payload_data = payload.model_dump(exclude_unset=True)
         current_user, current_tenant_id = current_account_with_tenant()
 
         # check embedding model setting
         if (
-            data.get("indexing_technique") == "high_quality"
-            and data.get("embedding_model_provider") is not None
-            and data.get("embedding_model") is not None
+            payload.indexing_technique == "high_quality"
+            and payload.embedding_model_provider is not None
+            and payload.embedding_model is not None
         ):
             DatasetService.check_embedding_model_setting(
-                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
+                dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
             )
 
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         DatasetPermissionService.check_permission(
-            current_user, dataset, data.get("permission"), data.get("partial_member_list")
+            current_user, dataset, payload.permission, payload.partial_member_list
         )
 
-        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
+        dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
 
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -518,15 +447,10 @@ class DatasetApi(Resource):
         result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
         tenant_id = current_tenant_id
 
-        if data.get("partial_member_list") and data.get("permission") == "partial_members":
-            DatasetPermissionService.update_partial_member_list(
-                tenant_id, dataset_id_str, data.get("partial_member_list")
-            )
+        if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
+            DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
         # clear partial member list when permission is only_me or all_team_members
-        elif (
-            data.get("permission") == DatasetPermissionEnum.ONLY_ME
-            or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
-        ):
+        elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
             DatasetPermissionService.clear_partial_member_list(dataset_id_str)
 
         partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@@ -615,24 +539,10 @@ class DatasetIndexingEstimateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("info_list", type=dict, required=True, nullable=True, location="json")
-            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                required=True,
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                location="json",
-            )
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump()
         _, current_tenant_id = current_account_with_tenant()
 
         # validate args
         DocumentService.estimate_args_validate(args)

File: (path not shown)

@@ -6,31 +6,14 @@ from typing import Literal, cast
 import sqlalchemy as sa
 from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel
 from sqlalchemy import asc, desc, select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
-from controllers.console.app.error import (
-    ProviderModelCurrentlyNotSupportError,
-    ProviderNotInitializeError,
-    ProviderQuotaExceededError,
-)
-from controllers.console.datasets.error import (
-    ArchivedDocumentImmutableError,
-    DocumentAlreadyFinishedError,
-    DocumentIndexingError,
-    IndexingEstimateError,
-    InvalidActionError,
-    InvalidMetadataError,
-)
-from controllers.console.wraps import (
-    account_initialization_required,
-    cloud_edition_billing_rate_limit_check,
-    cloud_edition_billing_resource_check,
-    setup_required,
-)
 from core.errors.error import (
     LLMBadRequestError,
     ModelCurrentlyNotSupportError,
@@ -55,10 +38,30 @@ from fields.document_fields import (
 )
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
-from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
 from models.dataset import DocumentPipelineExecutionLog
 from services.dataset_service import DatasetService, DocumentService
-from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
+
+from ..app.error import (
+    ProviderModelCurrentlyNotSupportError,
+    ProviderNotInitializeError,
+    ProviderQuotaExceededError,
+)
+from ..datasets.error import (
+    ArchivedDocumentImmutableError,
+    DocumentAlreadyFinishedError,
+    DocumentIndexingError,
+    IndexingEstimateError,
+    InvalidActionError,
+    InvalidMetadataError,
+)
+from ..wraps import (
+    account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
+    cloud_edition_billing_resource_check,
+    setup_required,
+)
 
 logger = logging.getLogger(__name__)
@@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
 dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
 
 
+class DocumentRetryPayload(BaseModel):
+    document_ids: list[str]
+
+
+class DocumentRenamePayload(BaseModel):
+    name: str
+
+
+register_schema_models(
+    console_ns,
+    KnowledgeConfig,
+    ProcessRule,
+    RetrievalModel,
+    DocumentRetryPayload,
+    DocumentRenamePayload,
+)
+
+
 class DocumentResource(Resource):
     def get_document(self, dataset_id: str, document_id: str) -> Document:
         current_user, current_tenant_id = current_account_with_tenant()
@@ -311,6 +332,7 @@ class DatasetDocumentListApi(Resource):
     @marshal_with(dataset_and_document_model)
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
     def post(self, dataset_id):
         current_user, _ = current_account_with_tenant()
         dataset_id = str(dataset_id)
@@ -329,23 +351,7 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
-            )
-            .add_argument("data_source", type=dict, required=False, location="json")
-            .add_argument("process_rule", type=dict, required=False, location="json")
-            .add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
-            .add_argument("original_document_id", type=str, required=False, location="json")
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
-            .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
-            .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-        )
-        args = parser.parse_args()
-        knowledge_config = KnowledgeConfig.model_validate(args)
+        knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
 
         if not dataset.indexing_technique and not knowledge_config.indexing_technique:
             raise ValueError("indexing_technique is required.")
@@ -391,17 +397,7 @@ class DatasetDocumentListApi(Resource):
 class DatasetInitApi(Resource):
     @console_ns.doc("init_dataset")
     @console_ns.doc(description="Initialize dataset with documents")
-    @console_ns.expect(
-        console_ns.model(
-            "DatasetInitRequest",
-            {
-                "upload_file_id": fields.String(required=True, description="Upload file ID"),
-                "indexing_technique": fields.String(description="Indexing technique"),
-                "process_rule": fields.Raw(description="Processing rules"),
-                "data_source": fields.Raw(description="Data source configuration"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
     @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
     @console_ns.response(400, "Invalid request parameters")
     @setup_required
@@ -416,27 +412,7 @@ class DatasetInitApi(Resource):
         if not current_user.is_dataset_editor:
             raise Forbidden()
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                required=True,
-                nullable=False,
-                location="json",
-            )
-            .add_argument("data_source", type=dict, required=True, nullable=True, location="json")
-            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-            .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
-            .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
-            .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
-        )
-        args = parser.parse_args()
-        knowledge_config = KnowledgeConfig.model_validate(args)
+        knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
 
         if knowledge_config.indexing_technique == "high_quality":
             if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
                 raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@@ -444,9 +420,9 @@ class DatasetInitApi(Resource):
                 model_manager = ModelManager()
                 model_manager.get_model_instance(
                     tenant_id=current_tenant_id,
-                    provider=args["embedding_model_provider"],
+                    provider=knowledge_config.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=args["embedding_model"],
+                    model=knowledge_config.embedding_model,
                 )
             except InvokeAuthorizationError:
                 raise ProviderNotInitializeError(
@@ -1077,19 +1053,16 @@ class DocumentRetryApi(DocumentResource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
     def post(self, dataset_id):
         """retry document."""
-        parser = reqparse.RequestParser().add_argument(
-            "document_ids", type=list, required=True, nullable=False, location="json"
-        )
-        args = parser.parse_args()
+        payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
 
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         retry_documents = []
         if not dataset:
             raise NotFound("Dataset not found.")
-        for document_id in args["document_ids"]:
+        for document_id in payload.document_ids:
             try:
                 document_id = str(document_id)
@@ -1122,6 +1095,7 @@ class DocumentRenameApi(DocumentResource):
     @login_required
     @account_initialization_required
     @marshal_with(document_fields)
+    @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
     def post(self, dataset_id, document_id):
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         current_user, _ = current_account_with_tenant()
@@ -1131,11 +1105,10 @@ class DocumentRenameApi(DocumentResource):
         if not dataset:
             raise NotFound("Dataset not found.")
         DatasetService.check_dataset_operator_permission(current_user, dataset)
-        parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
+        payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
         try:
-            document = DocumentService.rename_document(dataset_id, document_id, args["name"])
+            document = DocumentService.rename_document(dataset_id, document_id, payload.name)
         except services.errors.document.DocumentIndexingError:
             raise DocumentIndexingError("Cannot delete document during indexing.")

File: (path not shown)

@@ -1,11 +1,13 @@
 import uuid
 
 from flask import request
-from flask_restx import Resource, marshal, reqparse
+from flask_restx import Resource, marshal
+from pydantic import BaseModel, Field
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import (
@@ -36,6 +38,56 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
 from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
 
 
+class SegmentListQuery(BaseModel):
+    limit: int = Field(default=20, ge=1, le=100)
+    status: list[str] = Field(default_factory=list)
+    hit_count_gte: int | None = None
+    enabled: str = Field(default="all")
+    keyword: str | None = None
+    page: int = Field(default=1, ge=1)
+
+
+class SegmentCreatePayload(BaseModel):
+    content: str
+    answer: str | None = None
+    keywords: list[str] | None = None
+
+
+class SegmentUpdatePayload(BaseModel):
+    content: str
+    answer: str | None = None
+    keywords: list[str] | None = None
+    regenerate_child_chunks: bool = False
+
+
+class BatchImportPayload(BaseModel):
+    upload_file_id: str
+
+
+class ChildChunkCreatePayload(BaseModel):
+    content: str
+
+
+class ChildChunkUpdatePayload(BaseModel):
+    content: str
+
+
+class ChildChunkBatchUpdatePayload(BaseModel):
+    chunks: list[ChildChunkUpdateArgs]
+
+
+register_schema_models(
+    console_ns,
+    SegmentListQuery,
+    SegmentCreatePayload,
+    SegmentUpdatePayload,
+    BatchImportPayload,
+    ChildChunkCreatePayload,
+    ChildChunkUpdatePayload,
+    ChildChunkBatchUpdatePayload,
+)
+
+
 @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
 class DatasetDocumentSegmentListApi(Resource):
     @setup_required
@@ -60,23 +112,18 @@ class DatasetDocumentSegmentListApi(Resource):
         if not document:
             raise NotFound("Document not found.")
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("limit", type=int, default=20, location="args")
-            .add_argument("status", type=str, action="append", default=[], location="args")
-            .add_argument("hit_count_gte", type=int, default=None, location="args")
-            .add_argument("enabled", type=str, default="all", location="args")
-            .add_argument("keyword", type=str, default=None, location="args")
-            .add_argument("page", type=int, default=1, location="args")
-        )
-        args = parser.parse_args()
-
-        page = args["page"]
-        limit = min(args["limit"], 100)
-        status_list = args["status"]
-        hit_count_gte = args["hit_count_gte"]
-        keyword = args["keyword"]
+        args = SegmentListQuery.model_validate(
+            {
+                **request.args.to_dict(),
+                "status": request.args.getlist("status"),
+            }
+        )
+        page = args.page
+        limit = min(args.limit, 100)
+        status_list = args.status
+        hit_count_gte = args.hit_count_gte
+        keyword = args.keyword
 
         query = (
             select(DocumentSegment)
@@ -96,10 +143,10 @@ class DatasetDocumentSegmentListApi(Resource):
         if keyword:
             query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
 
-        if args["enabled"].lower() != "all":
-            if args["enabled"].lower() == "true":
+        if args.enabled.lower() != "all":
+            if args.enabled.lower() == "true":
                 query = query.where(DocumentSegment.enabled == True)
-            elif args["enabled"].lower() == "false":
+            elif args.enabled.lower() == "false":
                 query = query.where(DocumentSegment.enabled == False)
 
         segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -210,6 +257,7 @@ class DatasetDocumentSegmentAddApi(Resource):
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
     def post(self, dataset_id, document_id):
         current_user, current_tenant_id = current_account_with_tenant()
@@ -246,15 +294,10 @@ class DatasetDocumentSegmentAddApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("content", type=str, required=True, nullable=False, location="json")
-            .add_argument("answer", type=str, required=False, nullable=True, location="json")
-            .add_argument("keywords", type=list, required=False, nullable=True, location="json")
-        )
-        args = parser.parse_args()
-        SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.create_segment(args, document, dataset)
+        payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
+        payload_dict = payload.model_dump(exclude_none=True)
+        SegmentService.segment_create_args_validate(payload_dict, document)
+        segment = SegmentService.create_segment(payload_dict, document, dataset)
         return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@@ -265,6 +308,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
     def patch(self, dataset_id, document_id, segment_id):
         current_user, current_tenant_id = current_account_with_tenant()
@@ -313,18 +357,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("content", type=str, required=True, nullable=False, location="json")
-            .add_argument("answer", type=str, required=False, nullable=True, location="json")
-            .add_argument("keywords", type=list, required=False, nullable=True, location="json")
-            .add_argument(
-                "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
-            )
-        )
-        args = parser.parse_args()
-        SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
+        payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
+        payload_dict = payload.model_dump(exclude_none=True)
+        SegmentService.segment_create_args_validate(payload_dict, document)
+        segment = SegmentService.update_segment(
+            SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
+        )
         return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
 
     @setup_required
@@ -377,6 +415,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[BatchImportPayload.__name__])
     def post(self, dataset_id, document_id):
         current_user, current_tenant_id = current_account_with_tenant()
@@ -391,11 +430,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         if not document:
             raise NotFound("Document not found.")
 
-        parser = reqparse.RequestParser().add_argument(
-            "upload_file_id", type=str, required=True, nullable=False, location="json"
-        )
-        args = parser.parse_args()
-        upload_file_id = args["upload_file_id"]
+        payload = BatchImportPayload.model_validate(console_ns.payload or {})
+        upload_file_id = payload.upload_file_id
 
         upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
         if not upload_file:
@@ -446,6 +482,7 @@ class ChildChunkAddApi(Resource):
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
     def post(self, dataset_id, document_id, segment_id):
         current_user, current_tenant_id = current_account_with_tenant()
@@ -491,13 +528,9 @@ class ChildChunkAddApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser().add_argument(
-            "content", type=str, required=True, nullable=False, location="json"
-        )
-        args = parser.parse_args()
         try:
-            content = args["content"]
-            child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
+            payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
+            child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
         except ChildChunkIndexingServiceError as e:
             raise ChildChunkIndexingError(str(e))
         return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@@ -529,18 +562,17 @@ class ChildChunkAddApi(Resource):
         )
         if not segment:
             raise NotFound("Segment not found.")
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("limit", type=int, default=20, location="args")
-            .add_argument("keyword", type=str, default=None, location="args")
-            .add_argument("page", type=int, default=1, location="args")
-        )
-        args = parser.parse_args()
-
-        page = args["page"]
-        limit = min(args["limit"], 100)
-        keyword = args["keyword"]
+        args = SegmentListQuery.model_validate(
+            {
+                "limit": request.args.get("limit", default=20, type=int),
+                "keyword": request.args.get("keyword"),
+                "page": request.args.get("page", default=1, type=int),
+            }
+        )
+        page = args.page
+        limit = min(args.limit, 100)
+        keyword = args.keyword
 
         child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
         return {
@@ -588,14 +620,9 @@ class ChildChunkAddApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser().add_argument(
-            "chunks", type=list, required=True, nullable=False, location="json"
-        )
-        args = parser.parse_args()
+        payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
         try:
-            chunks_data = args["chunks"]
-            chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
-            child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
+            child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
         except ChildChunkIndexingServiceError as e:
             raise ChildChunkIndexingError(str(e))
         return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@@ -665,6 +692,7 @@ class ChildChunkUpdateApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
+    @console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
     def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
         current_user, current_tenant_id = current_account_with_tenant()
@@ -711,13 +739,9 @@ class ChildChunkUpdateApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         # validate args
-        parser = reqparse.RequestParser().add_argument(
-            "content", type=str, required=True, nullable=False, location="json"
-        )
-        args = parser.parse_args()
         try:
-            content = args["content"]
-            child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
+            payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
+            child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
         except ChildChunkIndexingServiceError as e:
             raise ChildChunkIndexingError(str(e))
         return {"data": marshal(child_chunk, child_chunk_fields)}, 200

File: (path not shown)

@@ -1,8 +1,10 @@
 from flask import request
-from flask_restx import Resource, fields, marshal, reqparse
+from flask_restx import Resource, fields, marshal
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@@ -71,10 +73,38 @@ except KeyError:
 dataset_detail_model = _build_dataset_detail_model()
 
 
-def _validate_name(name: str) -> str:
-    if not name or len(name) < 1 or len(name) > 100:
-        raise ValueError("Name must be between 1 to 100 characters.")
-    return name
+class ExternalKnowledgeApiPayload(BaseModel):
+    name: str = Field(..., min_length=1, max_length=40)
+    settings: dict[str, object]
+
+
+class ExternalDatasetCreatePayload(BaseModel):
+    external_knowledge_api_id: str
+    external_knowledge_id: str
+    name: str = Field(..., min_length=1, max_length=40)
+    description: str | None = Field(None, max_length=400)
+    external_retrieval_model: dict[str, object] | None = None
+
+
+class ExternalHitTestingPayload(BaseModel):
+    query: str
+    external_retrieval_model: dict[str, object] | None = None
+    metadata_filtering_conditions: dict[str, object] | None = None
+
+
+class BedrockRetrievalPayload(BaseModel):
+    retrieval_setting: dict[str, object]
+    query: str
+    knowledge_id: str
+
+
+register_schema_models(
+    console_ns,
+    ExternalKnowledgeApiPayload,
+    ExternalDatasetCreatePayload,
+    ExternalHitTestingPayload,
+    BedrockRetrievalPayload,
+)
+
 
 @console_ns.route("/datasets/external-knowledge-api")
@@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
     def post(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                required=True,
-                help="Name is required. Name must be between 1 to 100 characters.",
-                type=_validate_name,
-            )
-            .add_argument(
-                "settings",
-                type=dict,
-                location="json",
-                nullable=False,
-                required=True,
-            )
-        )
-        args = parser.parse_args()
+        payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
 
-        ExternalDatasetService.validate_api_list(args["settings"])
+        ExternalDatasetService.validate_api_list(payload.settings)
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
         if not current_user.is_dataset_editor:
@@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
         try:
             external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
-                tenant_id=current_tenant_id, user_id=current_user.id, args=args
+                tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
     def patch(self, external_knowledge_api_id):
         current_user, current_tenant_id = current_account_with_tenant()
         external_knowledge_api_id = str(external_knowledge_api_id)
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                required=True,
-                help="type is required. Name must be between 1 to 100 characters.",
-                type=_validate_name,
-            )
-            .add_argument(
-                "settings",
-                type=dict,
-                location="json",
-                nullable=False,
-                required=True,
-            )
-        )
-        args = parser.parse_args()
-        ExternalDatasetService.validate_api_list(args["settings"])
+        payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
+        ExternalDatasetService.validate_api_list(payload.settings)
 
         external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
             tenant_id=current_tenant_id,
             user_id=current_user.id,
             external_knowledge_api_id=external_knowledge_api_id,
-            args=args,
+            args=payload.model_dump(),
         )
 
         return external_knowledge_api.to_dict(), 200
@@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource):
class ExternalDatasetCreateApi(Resource): class ExternalDatasetCreateApi(Resource):
@console_ns.doc("create_external_dataset") @console_ns.doc("create_external_dataset")
@console_ns.doc(description="Create external knowledge dataset") @console_ns.doc(description="Create external knowledge dataset")
@console_ns.expect( @console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
console_ns.model(
"CreateExternalDatasetRequest",
{
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
"name": fields.String(required=True, description="Dataset name"),
"description": fields.String(description="Dataset description"),
},
)
)
@console_ns.response(201, "External dataset created successfully", dataset_detail_model) @console_ns.response(201, "External dataset created successfully", dataset_detail_model)
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = ( payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.doc("test_external_knowledge_retrieval") @console_ns.doc("test_external_knowledge_retrieval")
@console_ns.doc(description="Test external knowledge retrieval for dataset") @console_ns.doc(description="Test external knowledge retrieval for dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
console_ns.model(
"ExternalHitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
},
)
)
@console_ns.response(200, "External hit testing completed successfully") @console_ns.response(200, "External hit testing completed successfully")
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = ( payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() HitTestingService.hit_testing_args_check(payload.model_dump())
.add_argument("query", type=str, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try: try:
response = HitTestingService.external_retrieve( response = HitTestingService.external_retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=payload.query,
account=current_user, account=current_user,
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=payload.external_retrieval_model,
metadata_filtering_conditions=args["metadata_filtering_conditions"], metadata_filtering_conditions=payload.metadata_filtering_conditions,
) )
return response return response
@@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource):
# this api is only for internal testing # this api is only for internal testing
@console_ns.doc("bedrock_retrieval_test") @console_ns.doc("bedrock_retrieval_test")
@console_ns.doc(description="Bedrock retrieval test (internal use only)") @console_ns.doc(description="Bedrock retrieval test (internal use only)")
@console_ns.expect( @console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
console_ns.model(
"BedrockRetrievalTestRequest",
{
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
"query": fields.String(required=True, description="Query text"),
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
},
)
)
@console_ns.response(200, "Bedrock retrieval test completed") @console_ns.response(200, "Bedrock retrieval test completed")
def post(self): def post(self):
parser = ( payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
)
args = parser.parse_args()
# Call the knowledge retrieval service # Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval( result = ExternalDatasetTestService.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"] payload.retrieval_setting, payload.query, payload.knowledge_id
) )
return result, 200 return result, 200
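One nuance in these handlers: model_dump() keeps optional fields as explicit None entries, while model_dump(exclude_none=True) drops them, which matters when a downstream service distinguishes "absent" from "null". A small sketch with an illustrative model:

from pydantic import BaseModel

class CreateSketch(BaseModel):  # illustrative only
    name: str
    description: str | None = None

print(CreateSketch(name="kb").model_dump())                   # {'name': 'kb', 'description': None}
print(CreateSketch(name="kb").model_dump(exclude_none=True))  # {'name': 'kb'}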
View File
@@ -1,13 +1,17 @@
from flask_restx import Resource, fields from flask_restx import Resource
from controllers.console import console_ns from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from libs.login import login_required
from controllers.console.wraps import (
from .. import console_ns
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from ..wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_rate_limit_check, cloud_edition_billing_rate_limit_check,
setup_required, setup_required,
) )
from libs.login import login_required
register_schema_model(console_ns, HitTestingPayload)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing") @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval") @console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
console_ns.model(
"HitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"top_k": fields.Integer(description="Number of top results to return"),
"score_threshold": fields.Float(description="Score threshold for filtering results"),
},
)
)
@console_ns.response(200, "Hit testing completed successfully") @console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str) dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args() payload = HitTestingPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args) self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args) return self.perform_hit_testing(dataset, args)
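The HitTestingPayload model referenced here is defined in hit_testing_base (next file); see the constraint sketch after its definition below.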
View File
@@ -1,6 +1,8 @@
import logging import logging
from typing import Any
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
@@ -27,6 +29,12 @@ from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
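The Field(max_length=250) constraint above moves the query length check into validation itself, so oversized queries are rejected before any service code runs. A sketch with the same constraint shape:

from pydantic import BaseModel, Field, ValidationError

class QuerySketch(BaseModel):  # same constraint shape as HitTestingPayload.query
    query: str = Field(max_length=250)

try:
    QuerySketch.model_validate({"query": "x" * 251})
except ValidationError:
    print("rejected: query longer than 250 characters")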
class DatasetsHitTestingBase: class DatasetsHitTestingBase:
@staticmethod @staticmethod
def get_and_validate_dataset(dataset_id: str): def get_and_validate_dataset(dataset_id: str):
@@ -43,19 +51,9 @@ class DatasetsHitTestingBase:
return dataset return dataset
@staticmethod @staticmethod
def hit_testing_args_check(args): def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args) HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
@staticmethod @staticmethod
def perform_hit_testing(dataset, args): def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
View File
@@ -1,8 +1,10 @@
from typing import Literal from typing import Literal
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
@@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata") @console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource): class DatasetMetadataCreateApi(Resource):
@setup_required @setup_required
@@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id): def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json") payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args() name = payload.name
name = args["name"]
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
@@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser().add_argument( metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
"operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)
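Validating the nested operation_data list in one model_validate call replaces the old two-step of reqparse type=list followed by MetadataOperationData.model_validate. A shape-alike sketch (the real MetadataOperationData lives in services.entities and its fields are not shown here, so the item shape below is hypothetical):

from pydantic import BaseModel

class ItemSketch(BaseModel):  # hypothetical item shape, for illustration only
    id: str

class OperationSketch(BaseModel):
    operation_data: list[ItemSketch]

parsed = OperationSketch.model_validate({"operation_data": [{"id": "doc-1"}]})
print(parsed.operation_data[0].id)  # doc-1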
View File
@@ -1,20 +1,63 @@
from typing import Any
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
class DatasourceCredentialPayload(BaseModel):
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any]
class DatasourceCredentialDeletePayload(BaseModel):
credential_id: str
class DatasourceCredentialUpdatePayload(BaseModel):
credential_id: str
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any] | None = None
class DatasourceCustomClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enable_oauth_custom_client: bool | None = None
class DatasourceDefaultPayload(BaseModel):
id: str
class DatasourceUpdateNamePayload(BaseModel):
credential_id: str
name: str = Field(max_length=100)
register_schema_models(
console_ns,
DatasourceCredentialPayload,
DatasourceCredentialDeletePayload,
DatasourceCredentialUpdatePayload,
DatasourceCustomClientPayload,
DatasourceDefaultPayload,
DatasourceUpdateNamePayload,
)
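Registering these models at import time is what lets the @console_ns.expect(console_ns.models[...]) decorators below resolve by class name. A throwaway-namespace sketch of the same mechanism (flask-restx only, no app required; the model name is a stand-in):

from flask_restx import Namespace
from pydantic import BaseModel

ns = Namespace("sketch")

class DeleteSketch(BaseModel):  # stand-in for DatasourceCredentialDeletePayload
    credential_id: str

ns.schema_model(DeleteSketch.__name__, DeleteSketch.model_json_schema())
print(DeleteSketch.__name__ in ns.models)  # True -- the lookup expect() relies on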
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url") @console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource): class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>") @console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource): class DatasourceAuth(Resource):
@console_ns.expect(parser_datasource) @console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_datasource.parse_args() payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
datasource_provider_service.add_datasource_api_key_provider( datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider_id=datasource_provider_id, provider_id=datasource_provider_id,
credentials=args["credentials"], credentials=payload.credentials,
name=args["name"], name=payload.name,
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
@@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200 return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource): class DatasourceAuthDeleteApi(Resource):
@console_ns.expect(parser_datasource_delete) @console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name provider_name = datasource_provider_id.provider_name
args = parser_datasource_delete.parse_args() payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=payload.credential_id,
provider=provider_name, provider=provider_name,
plugin_id=plugin_id, plugin_id=plugin_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource): class DatasourceAuthUpdateApi(Resource):
@console_ns.expect(parser_datasource_update) @console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
args = parser_datasource_update.parse_args() payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=payload.credential_id,
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
credentials=args.get("credentials", {}), credentials=payload.credentials or {},
name=args.get("name", None), name=payload.name,
) )
return {"result": "success"}, 201 return {"result": "success"}, 201
@@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource): class DatasourceAuthOauthCustomClient(Resource):
@console_ns.expect(parser_datasource_custom) @console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_datasource_custom.parse_args() payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params( datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}), client_params=payload.client_params or {},
enabled=args.get("enable_oauth_custom_client", False), enabled=payload.enable_oauth_custom_client or False,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@console_ns.expect(parser_default) @console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_default.parse_args() payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider( datasource_provider_service.set_default_datasource_provider(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
credential_id=args["id"], credential_id=payload.id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource): class DatasourceUpdateProviderNameApi(Resource):
@console_ns.expect(parser_update_name) @console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_update_name.parse_args() payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name( datasource_provider_service.update_datasource_provider_name(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
name=args["name"], name=payload.name,
credential_id=args["credential_id"], credential_id=payload.credential_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
View File
@@ -1,9 +1,11 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description: str) -> str:
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/rag/pipeline/templates") @console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource): class PipelineTemplateListApi(Resource):
@setup_required @setup_required
@@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
return pipeline_template, 200 return pipeline_template, 200
class Payload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] | None = None
register_schema_models(console_ns, Payload)
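The Field defaults on Payload stand in for reqparse's default="": an omitted description comes back as an empty string after model_dump, so PipelineTemplateInfoEntity sees the same shape as before. A sketch with the same field shapes:

from pydantic import BaseModel, Field

class TemplateSketch(BaseModel):  # same field shapes as the Payload model above
    name: str = Field(..., min_length=1, max_length=40)
    description: str = Field(default="", max_length=400)

print(TemplateSketch.model_validate({"name": "t"}).model_dump())
# {'name': 't', 'description': ''}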
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>") @console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource): class CustomizedPipelineTemplateApi(Resource):
@setup_required @setup_required
@@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def patch(self, template_id: str): def patch(self, template_id: str):
parser = ( payload = Payload.model_validate(console_ns.payload or {})
reqparse.RequestParser() pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200 return 200
@@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish") @console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource): class PublishCustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[Payload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@knowledge_pipeline_publish_enabled @knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str): def post(self, pipeline_id: str):
parser = ( payload = Payload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
return {"result": "success"} return {"result": "success"}
View File
@@ -1,8 +1,10 @@
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.common.schema import register_schema_model
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import ( from controllers.console.wraps import (
@@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineDatasetImportPayload(BaseModel):
yaml_content: str
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
@console_ns.route("/rag/pipeline/dataset") @console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource): class CreateRagPipelineDatasetApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser().add_argument( payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
), ),
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
partial_member_list=None, partial_member_list=None,
yaml_content=args["yaml_content"], yaml_content=payload.yaml_content,
) )
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:
View File
@@ -1,11 +1,13 @@
import logging import logging
from typing import NoReturn from typing import Any, NoReturn
from flask import Response from flask import Response, request
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
DraftWorkflowNotExist, DraftWorkflowNotExist,
@@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
def _create_pagination_parser(): def _create_pagination_parser():
parser = ( class PaginationQuery(BaseModel):
reqparse.RequestParser() page: int = Field(default=1, ge=1, le=100_000)
.add_argument( limit: int = Field(default=20, ge=1, le=100)
"page",
type=inputs.int_range(1, 100_000), register_schema_models(console_ns, PaginationQuery)
required=False,
default=1, return PaginationQuery
location="args",
help="the page of data requested",
) class WorkflowDraftVariablePatchPayload(BaseModel):
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") name: str | None = None
) value: Any | None = None
return parser
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
@@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
""" """
Get draft workflow Get draft workflow
""" """
parser = _create_pagination_parser() pagination = _create_pagination_parser()
args = parser.parse_args() query = pagination.model_validate(request.args.to_dict())
# fetch draft workflow by app_model # fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
@@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
) )
workflow_vars = draft_var_srv.list_variables_without_values( workflow_vars = draft_var_srv.list_variables_without_values(
app_id=pipeline.id, app_id=pipeline.id,
page=args.page, page=query.page,
limit=args.limit, limit=query.limit,
) )
return workflow_vars return workflow_vars
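Unlike console_ns.payload, request.args.to_dict() yields only strings, so the query model also performs the int coercion and range checks that inputs.int_range used to provide. A sketch with the same shape as PaginationQuery:

from pydantic import BaseModel, Field

class PaginationSketch(BaseModel):  # same shape as PaginationQuery above
    page: int = Field(default=1, ge=1, le=100_000)
    limit: int = Field(default=20, ge=1, le=100)

q = PaginationSketch.model_validate({"page": "3"})  # query-string values arrive as str
print(q.page + 1, q.limit)  # 4 20 -- coerced to int and defaulted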
@@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str): def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types: # Request payload for file types:
# #
@@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),
) )
args = parser.parse_args(strict=True) payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable = draft_var_srv.get_variable(variable_id=variable_id) variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None: if variable is None:
View File
@@ -1,6 +1,9 @@
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask import request
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
@@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportPayload(BaseModel):
mode: str
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
pipeline_id: str | None = None
class IncludeSecretQuery(BaseModel):
include_secret: str = Field(default="false")
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
@console_ns.route("/rag/pipelines/imports") @console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource): class RagPipelineImportApi(Resource):
@setup_required @setup_required
@@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
@account_initialization_required @account_initialization_required
@edit_permission_required @edit_permission_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self): def post(self):
# Check user role first # Check user role first
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("pipeline_id", type=str, location="json")
)
args = parser.parse_args()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
@@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
account = current_user account = current_user
result = import_service.import_rag_pipeline( result = import_service.import_rag_pipeline(
account=account, account=account,
import_mode=args["mode"], import_mode=payload.mode,
yaml_content=args.get("yaml_content"), yaml_content=payload.yaml_content,
yaml_url=args.get("yaml_url"), yaml_url=payload.yaml_url,
pipeline_id=args.get("pipeline_id"), pipeline_id=payload.pipeline_id,
dataset_name=args.get("name"), dataset_name=payload.name,
) )
session.commit() session.commit()
@@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
@edit_permission_required @edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args") query = IncludeSecretQuery.model_validate(request.args.to_dict())
args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
export_service = RagPipelineDslService(session) export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl( result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=args["include_secret"] == "true" pipeline=pipeline, include_secret=query.include_secret == "true"
) )
return {"data": result}, 200 return {"data": result}, 200
View File
@@ -1,14 +1,16 @@
import json import json
import logging import logging
from typing import cast from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request from flask import abort, request
from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore from flask_restx import Resource, marshal_with # type: ignore
from flask_restx.inputs import int_range # type: ignore from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
ConversationCompletedError, ConversationCompletedError,
@@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
workflow_run_pagination_fields, workflow_run_pagination_fields,
) )
from libs import helper from libs import helper
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required from libs.login import current_account_with_tenant, current_user, login_required
from models import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
@@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DraftWorkflowSyncPayload(BaseModel):
graph: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] | None = None
conversation_variables: list[dict[str, Any]] | None = None
rag_pipeline_variables: list[dict[str, Any]] | None = None
features: dict[str, Any] | None = None
class NodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class NodeRunRequiredPayload(BaseModel):
inputs: dict[str, Any]
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
class DraftWorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
datasource_info_list: list[dict[str, Any]]
start_node_id: str
class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
is_preview: bool = False
response_mode: Literal["streaming", "blocking"] = "streaming"
original_document_id: str | None = None
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class NodeIdQuery(BaseModel):
node_id: str
class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class DatasourceVariablesPayload(BaseModel):
datasource_type: str
datasource_info: dict[str, Any]
start_node_id: str
start_node_title: str
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
NodeRunPayload,
NodeRunRequiredPayload,
DatasourceNodeRunPayload,
DraftWorkflowRunPayload,
PublishedWorkflowRunPayload,
DefaultBlockConfigQuery,
WorkflowListQuery,
WorkflowUpdatePayload,
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
)
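The Literal annotation on response_mode gives enum-style validation the old reqparse arguments never enforced, and is_preview simply defaults instead of being declared required with a default. A sketch of the constrained fields:

from typing import Literal
from pydantic import BaseModel, ValidationError

class RunSketch(BaseModel):  # response_mode/is_preview shaped like PublishedWorkflowRunPayload
    response_mode: Literal["streaming", "blocking"] = "streaming"
    is_preview: bool = False

print(RunSketch.model_validate({}).response_mode)  # streaming
try:
    RunSketch.model_validate({"response_mode": "batch"})
except ValidationError:
    print("rejected: not one of the allowed literals")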
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource): class DraftRagPipelineApi(Resource):
@setup_required @setup_required
@@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = ( payload_dict = console_ns.payload or {}
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=False, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
data = json.loads(request.data.decode("utf-8")) data = json.loads(request.data.decode("utf-8"))
@@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
if not isinstance(data.get("graph"), dict): if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict") raise ValueError("graph is not a dict")
args = { payload_dict = {
"graph": data.get("graph"), "graph": data.get("graph"),
"features": data.get("features"), "features": data.get("features"),
"hash": data.get("hash"), "hash": data.get("hash"),
@@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
else: else:
abort(415) abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
try: try:
environment_variables_list = args.get("environment_variables") or [] environment_variables_list = payload.environment_variables or []
environment_variables = [ environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
] ]
conversation_variables_list = args.get("conversation_variables") or [] conversation_variables_list = payload.conversation_variables or []
conversation_variables = [ conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
] ]
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow( workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline, pipeline=pipeline,
graph=args["graph"], graph=payload.graph,
unique_hash=args.get("hash"), unique_hash=payload.hash,
account=current_user, account=current_user,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
rag_pipeline_variables=args.get("rag_pipeline_variables") or [], rag_pipeline_variables=payload.rag_pipeline_variables or [],
) )
except WorkflowHashNotEqualError: except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync() raise DraftWorkflowNotSync()
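Both content-type branches above now assemble a plain payload_dict and funnel it through a single model_validate call, so application/json and text/plain bodies get identical validation. A sketch of the text/plain path with illustratively shaped fields:

import json
from typing import Any
from pydantic import BaseModel

class SyncSketch(BaseModel):  # graph/hash shaped like DraftWorkflowSyncPayload
    graph: dict[str, Any]
    hash: str | None = None

raw = json.loads('{"graph": {"nodes": []}, "hash": null}')  # e.g. a text/plain body
print(SyncSketch.model_validate(raw).graph)  # {'nodes': []}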
@@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
} }
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource): class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.expect(parser_run) @console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_run.parse_args() payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try: try:
response = PipelineGenerateService.generate_single_iteration( response = PipelineGenerateService.generate_single_iteration(
@@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource): class RagPipelineDraftRunLoopNodeApi(Resource):
@console_ns.expect(parser_run) @console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_run.parse_args() payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try: try:
response = PipelineGenerateService.generate_single_loop( response = PipelineGenerateService.generate_single_loop(
@@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError() raise InternalServerError()
parser_draft_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource): class DraftRagPipelineRunApi(Resource):
@console_ns.expect(parser_draft_run) @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_draft_run.parse_args() payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
try: try:
response = PipelineGenerateService.generate( response = PipelineGenerateService.generate(
@@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description) raise InvokeRateLimitHttpError(ex.description)
parser_published_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource): class PublishedRagPipelineRunApi(Resource):
@console_ns.expect(parser_published_run) @console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_published_run.parse_args() payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = args["response_mode"] == "streaming" streaming = payload.response_mode == "streaming"
try: try:
response = PipelineGenerateService.generate( response = PipelineGenerateService.generate(
pipeline=pipeline, pipeline=pipeline,
user=current_user, user=current_user,
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED, invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
streaming=streaming, streaming=streaming,
) )
@@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
# #
# return result # return result
# #
parser_rag_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource): class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run) @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_rag_run.parse_args() payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response( return helper.compact_generate_response(
@@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node( rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, pipeline=pipeline,
node_id=node_id, node_id=node_id,
user_inputs=inputs, user_inputs=payload.inputs,
account=current_user, account=current_user,
datasource_type=datasource_type, datasource_type=payload.datasource_type,
is_published=False, is_published=False,
credential_id=args.get("credential_id"), credential_id=payload.credential_id,
) )
) )
) )
@@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource): class RagPipelineDraftDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run) @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@edit_permission_required @edit_permission_required
@@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_rag_run.parse_args() payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response( return helper.compact_generate_response(
@@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node( rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, pipeline=pipeline,
node_id=node_id, node_id=node_id,
user_inputs=inputs, user_inputs=payload.inputs,
account=current_user, account=current_user,
datasource_type=datasource_type, datasource_type=payload.datasource_type,
is_published=False, is_published=False,
credential_id=args.get("credential_id"), credential_id=payload.credential_id,
) )
) )
) )
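In both datasource handlers above, the removed manual None checks become required model fields, so model_validate raises a ValidationError before the handler body runs. A minimal sketch of what DatasourceNodeRunPayload presumably declares (its definition sits above this hunk and is not shown here), inferred from the removed parser_rag_run arguments:

from pydantic import BaseModel

class DatasourceNodeRunPayload(BaseModel):
    # Required fields fail validation up front, replacing the old
    # "missing inputs" / "missing datasource_type" ValueErrors.
    inputs: dict[str, object]
    datasource_type: str
    # Optional, matching the old required=False argument.
    credential_id: str | None = None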
parser_run_api = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource): class RagPipelineDraftNodeRunApi(Resource):
@console_ns.expect(parser_run_api) @console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@edit_permission_required @edit_permission_required
@@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_run_api.parse_args() payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
inputs = payload.inputs
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node( workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
@@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs() return rag_pipeline_service.get_default_block_configs()
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource): class DefaultRagPipelineBlockConfigApi(Resource):
@console_ns.expect(parser_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
""" """
Get default block config Get default block config
""" """
args = parser_default.parse_args() query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
q = args.get("q")
filters = None filters = None
if q: if query.q:
try: try:
filters = json.loads(args.get("q", "")) filters = json.loads(query.q)
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError("Invalid filters") raise ValueError("Invalid filters")
@@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
parser_wf = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource): class PublishedAllRagPipelineApi(Resource):
@console_ns.expect(parser_wf)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_wf.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
query = WorkflowListQuery.model_validate(request.args.to_dict())

page = query.page
limit = query.limit
user_id = query.user_id
named_only = query.named_only
if user_id: if user_id:
if user_id != current_user.id: if user_id != current_user.id:
raise Forbidden() raise Forbidden()
user_id = cast(str, user_id)
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session: with Session(db.engine) as session:
@@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
} }
parser_wf_id = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource): class RagPipelineByIdApi(Resource):
@console_ns.expect(parser_wf_id)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
# Check permission # Check permission
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_wf_id.parse_args() payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
if not update_data: if not update_data:
return {"message": "No valid fields to update"}, 400 return {"message": "No valid fields to update"}, 400
@@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
return workflow return workflow
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource): class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
args = parser_parameters.parse_args() query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = args.get("node_id") node_id = query.node_id
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return { return {
@@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource): class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
args = parser_parameters.parse_args() query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = args.get("node_id") node_id = query.node_id
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return { return {
@@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource): class DraftRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
args = parser_parameters.parse_args() query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = args.get("node_id") node_id = query.node_id
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
return { return {
@@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource): class DraftRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
args = parser_parameters.parse_args() query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = args.get("node_id") node_id = query.node_id
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
@@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
} }
parser_wf_run = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource): class RagPipelineWorkflowRunListApi(Resource):
@console_ns.expect(parser_wf_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
args = parser_wf_run.parse_args() query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": str(query.last_id) if query.last_id else None,
"limit": query.limit,
}
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
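Because the query model returns typed values, the handler converts back to the plain-string dict the service layer expects. A sketch of what WorkflowRunQuery presumably looks like, showing the coercion that replaces the uuid_value and int_range reqparse types:

from uuid import UUID
from pydantic import BaseModel, Field

class WorkflowRunQuery(BaseModel):
    last_id: UUID | None = None                   # uuid_value replacement
    limit: int = Field(default=20, ge=1, le=100)  # int_range(1, 100) replacement

query = WorkflowRunQuery.model_validate({"last_id": None, "limit": "20"})
assert query.limit == 20  # numeric strings from the query string are coerced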
@@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
return result return result
parser_var = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource): class RagPipelineDatasourceVariableApi(Resource):
@console_ns.expect(parser_var) @console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
Set datasource variables Set datasource variables
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_var.parse_args() args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables( workflow_node_execution = rag_pipeline_service.set_datasource_variables(

View File

@@ -1,5 +1,10 @@
from flask_restx import Resource, fields, reqparse from typing import Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
@@ -7,48 +12,35 @@ from libs.login import login_required
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
class WebsiteCrawlPayload(BaseModel):
provider: Literal["firecrawl", "watercrawl", "jinareader"]
url: str
options: dict[str, object]
class WebsiteCrawlStatusQuery(BaseModel):
provider: Literal["firecrawl", "watercrawl", "jinareader"]
register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery)
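This is the recurring shape of the port: register the BaseModel once, point @expect at the generated schema, then validate the JSON body inside the handler. A minimal sketch with a hypothetical ExamplePayload:

from flask_restx import Resource
from pydantic import BaseModel

from controllers.common.schema import register_schema_models
from controllers.console import console_ns

class ExamplePayload(BaseModel):
    name: str

register_schema_models(console_ns, ExamplePayload)

@console_ns.route("/example")  # hypothetical endpoint for illustration
class ExampleApi(Resource):
    # The registered schema is looked up by class name.
    @console_ns.expect(console_ns.models[ExamplePayload.__name__])
    def post(self):
        # console_ns.payload is the parsed JSON body; guard against None.
        payload = ExamplePayload.model_validate(console_ns.payload or {})
        return {"name": payload.name}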
@console_ns.route("/website/crawl") @console_ns.route("/website/crawl")
class WebsiteCrawlApi(Resource): class WebsiteCrawlApi(Resource):
@console_ns.doc("crawl_website") @console_ns.doc("crawl_website")
@console_ns.doc(description="Crawl website content") @console_ns.doc(description="Crawl website content")
@console_ns.expect( @console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__])
console_ns.model(
"WebsiteCrawlRequest",
{
"provider": fields.String(
required=True,
description="Crawl provider (firecrawl/watercrawl/jinareader)",
enum=["firecrawl", "watercrawl", "jinareader"],
),
"url": fields.String(required=True, description="URL to crawl"),
"options": fields.Raw(required=True, description="Crawl options"),
},
)
)
@console_ns.response(200, "Website crawl initiated successfully") @console_ns.response(200, "Website crawl initiated successfully")
@console_ns.response(400, "Invalid crawl parameters") @console_ns.response(400, "Invalid crawl parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = ( payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
)
.add_argument("url", type=str, required=True, nullable=True, location="json")
.add_argument("options", type=dict, required=True, nullable=True, location="json")
)
args = parser.parse_args()
# Create typed request and validate # Create typed request and validate
try: try:
api_request = WebsiteCrawlApiRequest.from_args(args) api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump())
except ValueError as e: except ValueError as e:
raise WebsiteCrawlError(str(e)) raise WebsiteCrawlError(str(e))
@@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource):
@console_ns.doc("get_crawl_status") @console_ns.doc("get_crawl_status")
@console_ns.doc(description="Get website crawl status") @console_ns.doc(description="Get website crawl status")
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"}) @console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
@console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__])
@console_ns.response(200, "Crawl status retrieved successfully") @console_ns.response(200, "Crawl status retrieved successfully")
@console_ns.response(404, "Crawl job not found") @console_ns.response(404, "Crawl job not found")
@console_ns.response(400, "Invalid provider") @console_ns.response(400, "Invalid provider")
@@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, job_id: str): def get(self, job_id: str):
parser = reqparse.RequestParser().add_argument( args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict())
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
# Create typed request and validate # Create typed request and validate
try: try:
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id) api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id)
except ValueError as e: except ValueError as e:
raise WebsiteCrawlError(str(e)) raise WebsiteCrawlError(str(e))
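A short check of what the Literal field buys over reqparse's choices=[...]: unknown providers are rejected before the handler body runs. Usage sketch against the WebsiteCrawlStatusQuery defined above:

from pydantic import ValidationError

WebsiteCrawlStatusQuery.model_validate({"provider": "firecrawl"})  # accepted
try:
    WebsiteCrawlStatusQuery.model_validate({"provider": "scrapy"})
except ValidationError:
    pass  # unknown provider fails validation, as choices=[...] did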

View File

@@ -1,9 +1,11 @@
import logging import logging
from flask import request from flask import request
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.common.schema import register_schema_model
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
AudioTooLargeError, AudioTooLargeError,
@@ -31,6 +33,16 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload)
@console_ns.route( @console_ns.route(
"/installed-apps/<uuid:installed_app_id>/audio-to-text", "/installed-apps/<uuid:installed_app_id>/audio-to-text",
endpoint="installed_app_audio", endpoint="installed_app_audio",
@@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource):
endpoint="installed_app_text", endpoint="installed_app_text",
) )
class ChatTextApi(InstalledAppResource): class ChatTextApi(InstalledAppResource):
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
def post(self, installed_app): def post(self, installed_app):
from flask_restx import reqparse
app_model = installed_app.app app_model = installed_app.app
try: try:
parser = ( payload = TextToAudioPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None) message_id = payload.message_id
text = args.get("text", None) text = payload.text
voice = args.get("voice", None) voice = payload.voice
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response return response

View File

@@ -1,9 +1,12 @@
import logging import logging
from typing import Any, Literal
from uuid import UUID
from flask_restx import reqparse from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,
@@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from libs.login import current_user from libs.login import current_user
from models import Account from models import Account
from models.model import AppMode from models.model import AppMode
@@ -38,28 +40,42 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="explore_app")
class ChatMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
conversation_id: UUID | None = None
parent_message_id: UUID | None = None
retriever_from: str = Field(default="explore_app")
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user # define completion api for user
@console_ns.route( @console_ns.route(
"/installed-apps/<uuid:installed_app_id>/completion-messages", "/installed-apps/<uuid:installed_app_id>/completion-messages",
endpoint="installed_app_completion", endpoint="installed_app_completion",
) )
class CompletionApi(InstalledAppResource): class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION: if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError() raise NotCompletionAppError()
parser = ( payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = naive_utc_now() installed_app.last_used_at = naive_utc_now()
@@ -123,22 +139,15 @@ class CompletionStopApi(InstalledAppResource):
endpoint="installed_app_chat_completion", endpoint="installed_app_chat_completion",
) )
class ChatApi(InstalledAppResource): class ChatApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = ( payload = ChatMessagePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
args["auto_generate_name"] = False args["auto_generate_name"] = False

View File

@@ -1,14 +1,18 @@
from flask_restx import marshal_with, reqparse from typing import Any
from flask_restx.inputs import int_range from uuid import UUID
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from libs.login import current_user from libs.login import current_user
from models import Account from models import Account
from models.model import AppMode from models.model import AppMode
@@ -19,29 +23,44 @@ from services.web_conversation_service import WebConversationService
from .. import console_ns from .. import console_ns
class ConversationListQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str
auto_generate: bool = False
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
@console_ns.route( @console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations", "/installed-apps/<uuid:installed_app_id>/conversations",
endpoint="installed_app_conversations", endpoint="installed_app_conversations",
) )
class ConversationListApi(InstalledAppResource): class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app): def get(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = (
    reqparse.RequestParser()
    .add_argument("last_id", type=uuid_value, location="args")
    .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
    .add_argument("pinned", type=str, choices=["true", "false", None], location="args")
)
args = parser.parse_args()

pinned = None
if "pinned" in args and args["pinned"] is not None:
    pinned = args["pinned"] == "true"

raw_args: dict[str, Any] = {
    "last_id": request.args.get("last_id"),
    "limit": request.args.get("limit", default=20, type=int),
    "pinned": request.args.get("pinned"),
}
pinned_value = raw_args["pinned"]
if isinstance(pinned_value, str):
    raw_args["pinned"] = pinned_value == "true"
args = ConversationListQuery.model_validate(raw_args)
try: try:
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
@@ -51,10 +70,10 @@ class ConversationListApi(InstalledAppResource):
session=session, session=session,
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
last_id=args["last_id"], last_id=str(args.last_id) if args.last_id else None,
limit=args["limit"], limit=args.limit,
invoke_from=InvokeFrom.EXPLORE, invoke_from=InvokeFrom.EXPLORE,
pinned=pinned, pinned=args.pinned,
) )
except LastConversationNotExistsError: except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
@@ -88,6 +107,7 @@ class ConversationApi(InstalledAppResource):
) )
class ConversationRenameApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
@@ -96,18 +116,13 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
parser = ( payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
try: try:
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance") raise ValueError("current_user must be an Account instance")
return ConversationService.rename( return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"] app_model, conversation_id, current_user, payload.name, payload.auto_generate
) )
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@@ -1,9 +1,13 @@
import logging import logging
from typing import Literal
from uuid import UUID
from flask_restx import marshal_with, reqparse from flask import request
from flask_restx.inputs import int_range from flask_restx import marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console.app.error import ( from controllers.console.app.error import (
AppMoreLikeThisDisabledError, AppMoreLikeThisDisabledError,
CompletionRequestError, CompletionRequestError,
@@ -22,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper from libs import helper
from libs.helper import uuid_value
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
@@ -40,12 +43,31 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"]
register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
@console_ns.route( @console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages", "/installed-apps/<uuid:installed_app_id>/messages",
endpoint="installed_app_messages", endpoint="installed_app_messages",
) )
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
@@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
args = MessageListQuery.model_validate(request.args.to_dict())
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
try: try:
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
) )
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource):
endpoint="installed_app_message_feedback", endpoint="installed_app_message_feedback",
) )
class MessageFeedbackApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
def post(self, installed_app, message_id): def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
message_id = str(message_id) message_id = str(message_id)
parser = ( payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json")
)
args = parser.parse_args()
try: try:
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
user=current_user, user=current_user,
rating=args.get("rating"), rating=payload.rating,
content=args.get("content"), content=payload.content,
) )
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource):
endpoint="installed_app_more_like_this", endpoint="installed_app_more_like_this",
) )
class MessageMoreLikeThisApi(InstalledAppResource): class MessageMoreLikeThisApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
@@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser().add_argument( args = MoreLikeThisQuery.model_validate(request.args.to_dict())
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args.response_mode == "streaming"
try: try:
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(

View File

@@ -1,16 +1,33 @@
from flask_restx import fields, marshal_with, reqparse from uuid import UUID
from flask_restx.inputs import int_range
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUID
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
feedback_fields = {"rating": fields.String} feedback_fields = {"rating": fields.String}
message_fields = { message_fields = {
@@ -33,32 +50,33 @@ class SavedMessageListApi(InstalledAppResource):
} }
@marshal_with(saved_message_infinite_scroll_pagination_fields) @marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = (
    reqparse.RequestParser()
    .add_argument("last_id", type=uuid_value, location="args")
    .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])

args = SavedMessageListQuery.model_validate(request.args.to_dict())

return SavedMessageService.pagination_by_last_id(
    app_model,
    current_user,
    str(args.last_id) if args.last_id else None,
    args.limit,
)
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app): def post(self, installed_app):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args()
try: try:
SavedMessageService.save(app_model, current_user, args["message_id"]) SavedMessageService.save(app_model, current_user, str(payload.message_id))
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

View File

@@ -1,8 +1,10 @@
import logging import logging
from typing import Any
from flask_restx import reqparse from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.common.schema import register_schema_model
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@@ -32,8 +34,17 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run") @console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource): class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
def post(self, installed_app: InstalledApp): def post(self, installed_app: InstalledApp):
""" """
Run workflow Run workflow
@@ -46,12 +57,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
parser = ( payload = WorkflowRunPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True

View File

@@ -1,29 +1,38 @@
from flask_restx import Resource, reqparse from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
from tasks.mail_inner_task import send_inner_email_task from tasks.mail_inner_task import send_inner_email_task
_mail_parser = (
reqparse.RequestParser() class InnerMailPayload(BaseModel):
.add_argument("to", type=str, action="append", required=True) to: list[str] = Field(description="Recipient email addresses", min_length=1)
.add_argument("subject", type=str, required=True) subject: str
.add_argument("body", type=str, required=True) body: str
.add_argument("substitutions", type=dict, required=False) substitutions: dict[str, Any] | None = None
)
register_schema_model(inner_api_ns, InnerMailPayload)
class BaseMail(Resource): class BaseMail(Resource):
"""Shared logic for sending an inner email.""" """Shared logic for sending an inner email."""
@inner_api_ns.doc("send_inner_mail")
@inner_api_ns.doc(description="Send internal email")
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
def post(self): def post(self):
args = _mail_parser.parse_args() args = InnerMailPayload.model_validate(inner_api_ns.payload or {})
send_inner_email_task.delay( # type: ignore send_inner_email_task.delay(
to=args["to"], to=args.to,
subject=args["subject"], subject=args.subject,
body=args["body"], body=args.body,
substitutions=args["substitutions"], substitutions=args.substitutions, # type: ignore
) )
return {"message": "success"}, 200 return {"message": "success"}, 200
@@ -34,7 +43,7 @@ class EnterpriseMail(BaseMail):
@inner_api_ns.doc("send_enterprise_mail") @inner_api_ns.doc("send_enterprise_mail")
@inner_api_ns.doc(description="Send internal email for enterprise features") @inner_api_ns.doc(description="Send internal email for enterprise features")
@inner_api_ns.expect(_mail_parser) @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
@inner_api_ns.doc( @inner_api_ns.doc(
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
) )
@@ -56,7 +65,7 @@ class BillingMail(BaseMail):
@inner_api_ns.doc("send_billing_mail") @inner_api_ns.doc("send_billing_mail")
@inner_api_ns.doc(description="Send internal email for billing notifications") @inner_api_ns.doc(description="Send internal email for billing notifications")
@inner_api_ns.expect(_mail_parser) @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
@inner_api_ns.doc( @inner_api_ns.doc(
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
) )

View File

@@ -1,10 +1,9 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar, cast from typing import ParamSpec, TypeVar
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
from flask_restx import reqparse
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -17,6 +16,11 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
user_id: str
def get_user(tenant_id: str, user_id: str | None) -> EndUser: def get_user(tenant_id: str, user_id: str | None) -> EndUser:
""" """
Get current user Get current user
@@ -67,58 +71,45 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model return user_model
def get_user_tenant(view: Callable[P, R] | None = None):
    def decorator(view_func: Callable[P, R]):
        @wraps(view_func)
        def decorated_view(*args: P.args, **kwargs: P.kwargs):
            # fetch json body
            parser = (
                reqparse.RequestParser()
                .add_argument("tenant_id", type=str, required=True, location="json")
                .add_argument("user_id", type=str, required=True, location="json")
            )
            p = parser.parse_args()

            user_id = cast(str, p.get("user_id"))
            tenant_id = cast(str, p.get("tenant_id"))

            if not tenant_id:
                raise ValueError("tenant_id is required")
            if not user_id:
                user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID

            try:
                tenant_model = (
                    db.session.query(Tenant)
                    .where(
                        Tenant.id == tenant_id,
                    )
                    .first()
                )
            except Exception:
                raise ValueError("tenant not found")

            if not tenant_model:
                raise ValueError("tenant not found")

            kwargs["tenant_model"] = tenant_model

            user = get_user(tenant_id, user_id)
            kwargs["user_model"] = user
            current_app.login_manager._update_request_context_with_user(user)  # type: ignore
            user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore
            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)

def get_user_tenant(view_func: Callable[P, R]):
    @wraps(view_func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs):
        payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})

        user_id = payload.user_id
        tenant_id = payload.tenant_id

        if not tenant_id:
            raise ValueError("tenant_id is required")
        if not user_id:
            user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID

        try:
            tenant_model = (
                db.session.query(Tenant)
                .where(
                    Tenant.id == tenant_id,
                )
                .first()
            )
        except Exception:
            raise ValueError("tenant not found")

        if not tenant_model:
            raise ValueError("tenant not found")

        kwargs["tenant_model"] = tenant_model

        user = get_user(tenant_id, user_id)
        kwargs["user_model"] = user
        current_app.login_manager._update_request_context_with_user(user)  # type: ignore
        user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore
        return view_func(*args, **kwargs)

    return decorated_view
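With the optional-view indirection gone, get_user_tenant is applied as a plain decorator. A hedged usage sketch (the route and namespace are hypothetical; the decorator injects tenant_model and user_model as keyword arguments, as the body above shows):

from flask_restx import Resource

from models import EndUser, Tenant

@inner_api_ns.route("/demo")  # hypothetical route for illustration
class DemoApi(Resource):
    @get_user_tenant
    def post(self, tenant_model: Tenant, user_model: EndUser):
        # Both models were resolved from the validated JSON body.
        return {"tenant_id": tenant_model.id, "user_id": user_model.id}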
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):

View File

@@ -1,7 +1,9 @@
 import json

-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel

+from controllers.common.schema import register_schema_models
 from controllers.console.wraps import setup_required
 from controllers.inner_api import inner_api_ns
 from controllers.inner_api.wraps import enterprise_inner_api_only
@@ -11,12 +13,25 @@ from models import Account
 from services.account_service import TenantService


+class WorkspaceCreatePayload(BaseModel):
+    name: str
+    owner_email: str
+
+
+class WorkspaceOwnerlessPayload(BaseModel):
+    name: str
+
+
+register_schema_models(inner_api_ns, WorkspaceCreatePayload, WorkspaceOwnerlessPayload)
+
+
 @inner_api_ns.route("/enterprise/workspace")
 class EnterpriseWorkspace(Resource):
     @setup_required
     @enterprise_inner_api_only
     @inner_api_ns.doc("create_enterprise_workspace")
     @inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment")
+    @inner_api_ns.expect(inner_api_ns.models[WorkspaceCreatePayload.__name__])
     @inner_api_ns.doc(
         responses={
             200: "Workspace created successfully",
@@ -25,18 +40,13 @@ class EnterpriseWorkspace(Resource):
         }
     )
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("name", type=str, required=True, location="json")
-            .add_argument("owner_email", type=str, required=True, location="json")
-        )
-        args = parser.parse_args()
+        args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})

-        account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
+        account = db.session.query(Account).filter_by(email=args.owner_email).first()
         if account is None:
             return {"message": "owner account not found."}, 404

-        tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
+        tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
         TenantService.create_tenant_member(tenant, account, role="owner")

         tenant_was_created.send(tenant)
@@ -62,6 +72,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
     @enterprise_inner_api_only
     @inner_api_ns.doc("create_enterprise_workspace_ownerless")
     @inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment")
+    @inner_api_ns.expect(inner_api_ns.models[WorkspaceOwnerlessPayload.__name__])
     @inner_api_ns.doc(
         responses={
             200: "Workspace created successfully",
@@ -70,10 +81,9 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
         }
     )
     def post(self):
-        parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
-        args = parser.parse_args()
+        args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})

-        tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
+        tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)

         tenant_was_created.send(tenant)
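The call-site change is the same across both endpoints: `parse_args()` returned a dict indexed as `args["name"]`, while the validated model exposes typed attributes. A quick sketch with made-up values:

```python
from pydantic import BaseModel, ValidationError


class WorkspaceCreatePayload(BaseModel):
    name: str
    owner_email: str


args = WorkspaceCreatePayload.model_validate({"name": "acme", "owner_email": "owner@example.com"})
assert args.name == "acme"  # was: args["name"]

# The `or {}` guard matters: a request with no JSON body yields a clear
# ValidationError listing both missing fields instead of a TypeError.
try:
    WorkspaceCreatePayload.model_validate({})
except ValidationError as exc:
    assert {e["loc"][0] for e in exc.errors()} == {"name", "owner_email"}
```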
View File
@@ -1,10 +1,11 @@
-from typing import Union
+from typing import Any, Union

 from flask import Response
-from flask_restx import Resource, reqparse
-from pydantic import ValidationError
+from flask_restx import Resource
+from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy.orm import Session

+from controllers.common.schema import register_schema_model
 from controllers.console.app.mcp_server import AppMCPServerStatus
 from controllers.mcp import mcp_ns
 from core.app.app_config.entities import VariableEntity
@@ -24,27 +25,19 @@ class MCPRequestError(Exception):
         super().__init__(message)


-def int_or_str(value):
-    """Validate that a value is either an integer or string."""
-    if isinstance(value, (int, str)):
-        return value
-    else:
-        return None
+class MCPRequestPayload(BaseModel):
+    jsonrpc: str = Field(description="JSON-RPC version (should be '2.0')")
+    method: str = Field(description="The method to invoke")
+    params: dict[str, Any] | None = Field(default=None, description="Parameters for the method")
+    id: int | str | None = Field(default=None, description="Request ID for tracking responses")


-# Define parser for both documentation and validation
-mcp_request_parser = (
-    reqparse.RequestParser()
-    .add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
-    .add_argument("method", type=str, required=True, location="json", help="The method to invoke")
-    .add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
-    .add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
-)
+register_schema_model(mcp_ns, MCPRequestPayload)


 @mcp_ns.route("/server/<string:server_code>/mcp")
 class MCPAppApi(Resource):
-    @mcp_ns.expect(mcp_request_parser)
+    @mcp_ns.expect(mcp_ns.models[MCPRequestPayload.__name__])
     @mcp_ns.doc("handle_mcp_request")
     @mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server")
     @mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"})
@@ -70,9 +63,9 @@ class MCPAppApi(Resource):
         Raises:
             ValidationError: Invalid request format or parameters
         """
-        args = mcp_request_parser.parse_args()
-        request_id: Union[int, str] | None = args.get("id")
-        mcp_request = self._parse_mcp_request(args)
+        args = MCPRequestPayload.model_validate(mcp_ns.payload or {})
+        request_id: Union[int, str] | None = args.id
+        mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))

         with Session(db.engine, expire_on_commit=False) as session:
             # Get MCP server and app
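The `int | str | None` annotation on `id` replaces the hand-rolled `int_or_str` helper, with one behavioral difference: pydantic rejects other types instead of silently mapping them to `None`. A sketch, assuming pydantic v2 smart-union semantics:

```python
from typing import Any
from pydantic import BaseModel, Field, ValidationError


class MCPRequestPayload(BaseModel):
    jsonrpc: str = Field(description="JSON-RPC version (should be '2.0')")
    method: str = Field(description="The method to invoke")
    params: dict[str, Any] | None = Field(default=None)
    id: int | str | None = Field(default=None)


# JSON-RPC ids may be numeric or string; the smart union keeps the input type.
req = MCPRequestPayload.model_validate({"jsonrpc": "2.0", "method": "ping", "id": 7})
assert req.id == 7 and isinstance(req.id, int)

# exclude_none mirrors how the handler rebuilds the raw request dict.
assert req.model_dump(exclude_none=True) == {"jsonrpc": "2.0", "method": "ping", "id": 7}

# A list id is now a validation error rather than being coerced to None.
try:
    MCPRequestPayload.model_validate({"jsonrpc": "2.0", "method": "ping", "id": [1]})
except ValidationError:
    pass
```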
View File
@@ -1,9 +1,11 @@
 from typing import Literal

 from flask import request
-from flask_restx import Api, Namespace, Resource, fields, reqparse
+from flask_restx import Api, Namespace, Resource, fields
 from flask_restx.api import HTTPStatus
+from pydantic import BaseModel, Field

+from controllers.common.schema import register_schema_models
 from controllers.console.wraps import edit_permission_required
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import validate_app_token
@@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model
 from models.model import App
 from services.annotation_service import AppAnnotationService

-# Define parsers for annotation API
-annotation_create_parser = (
-    reqparse.RequestParser()
-    .add_argument("question", required=True, type=str, location="json", help="Annotation question")
-    .add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
-)
-
-annotation_reply_action_parser = (
-    reqparse.RequestParser()
-    .add_argument(
-        "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
-    )
-    .add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
-    .add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
-)
+class AnnotationCreatePayload(BaseModel):
+    question: str = Field(description="Annotation question")
+    answer: str = Field(description="Annotation answer")
+
+
+class AnnotationReplyActionPayload(BaseModel):
+    score_threshold: float = Field(description="Score threshold for annotation matching")
+    embedding_provider_name: str = Field(description="Embedding provider name")
+    embedding_model_name: str = Field(description="Embedding model name")
+
+
+register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)


 @service_api_ns.route("/apps/annotation-reply/<string:action>")
 class AnnotationReplyActionApi(Resource):
-    @service_api_ns.expect(annotation_reply_action_parser)
+    @service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
     @service_api_ns.doc("annotation_reply_action")
     @service_api_ns.doc(description="Enable or disable annotation reply feature")
     @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
@@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource):
     @validate_app_token
     def post(self, app_model: App, action: Literal["enable", "disable"]):
         """Enable or disable annotation reply feature."""
-        args = annotation_reply_action_parser.parse_args()
+        args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
         if action == "enable":
             result = AppAnnotationService.enable_app_annotation(args, app_model.id)
         elif action == "disable":
@@ -126,7 +126,7 @@ class AnnotationListApi(Resource):
             "page": page,
         }

-    @service_api_ns.expect(annotation_create_parser)
+    @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
     @service_api_ns.doc("create_annotation")
     @service_api_ns.doc(description="Create a new annotation")
     @service_api_ns.doc(
@@ -139,14 +139,14 @@ class AnnotationListApi(Resource):
     @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
     def post(self, app_model: App):
         """Create a new annotation."""
-        args = annotation_create_parser.parse_args()
+        args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
         annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
         return annotation, 201


 @service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
 class AnnotationUpdateDeleteApi(Resource):
-    @service_api_ns.expect(annotation_create_parser)
+    @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
     @service_api_ns.doc("update_annotation")
     @service_api_ns.doc(description="Update an existing annotation")
     @service_api_ns.doc(params={"annotation_id": "Annotation ID"})
@@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource):
     @service_api_ns.marshal_with(build_annotation_model(service_api_ns))
     def put(self, app_model: App, annotation_id: str):
         """Update an existing annotation."""
-        args = annotation_create_parser.parse_args()
+        args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
         annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
         return annotation
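Because `AppAnnotationService` still expects plain dicts, these handlers validate and immediately `model_dump()`. A sketch of that round trip, including the numeric coercion `score_threshold` now gets for free (pydantic v2 lax mode assumed):

```python
from pydantic import BaseModel, Field


class AnnotationReplyActionPayload(BaseModel):
    score_threshold: float = Field(description="Score threshold for annotation matching")
    embedding_provider_name: str = Field(description="Embedding provider name")
    embedding_model_name: str = Field(description="Embedding model name")


raw = {
    "score_threshold": "0.8",  # JSON clients sometimes send numbers as strings
    "embedding_provider_name": "openai",
    "embedding_model_name": "text-embedding-3-small",
}

# Validate, then hand the service layer the same dict shape reqparse produced.
args = AnnotationReplyActionPayload.model_validate(raw).model_dump()
assert args["score_threshold"] == 0.8  # coerced to float
```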
View File
@@ -1,10 +1,12 @@
 import logging

 from flask import request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError

 import services
+from controllers.common.schema import register_schema_model
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import (
     AppUnavailableError,
@@ -84,19 +86,19 @@ class AudioApi(Resource):
             raise InternalServerError()


-# Define parser for text-to-audio API
-text_to_audio_parser = (
-    reqparse.RequestParser()
-    .add_argument("message_id", type=str, required=False, location="json", help="Message ID")
-    .add_argument("voice", type=str, location="json", help="Voice to use for TTS")
-    .add_argument("text", type=str, location="json", help="Text to convert to audio")
-    .add_argument("streaming", type=bool, location="json", help="Enable streaming response")
-)
+class TextToAudioPayload(BaseModel):
+    message_id: str | None = Field(default=None, description="Message ID")
+    voice: str | None = Field(default=None, description="Voice to use for TTS")
+    text: str | None = Field(default=None, description="Text to convert to audio")
+    streaming: bool | None = Field(default=None, description="Enable streaming response")
+
+
+register_schema_model(service_api_ns, TextToAudioPayload)


 @service_api_ns.route("/text-to-audio")
 class TextApi(Resource):
-    @service_api_ns.expect(text_to_audio_parser)
+    @service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
     @service_api_ns.doc("text_to_audio")
     @service_api_ns.doc(description="Convert text to audio using text-to-speech")
     @service_api_ns.doc(
@@ -114,11 +116,11 @@ class TextApi(Resource):
         Converts the provided text to audio using the specified voice.
         """
         try:
-            args = text_to_audio_parser.parse_args()
+            payload = TextToAudioPayload.model_validate(service_api_ns.payload or {})

-            message_id = args.get("message_id", None)
-            text = args.get("text", None)
-            voice = args.get("voice", None)
+            message_id = payload.message_id
+            text = payload.text
+            voice = payload.voice
             response = AudioService.transcript_tts(
                 app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
             )
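Every field on `TextToAudioPayload` is optional, so an empty or missing JSON body still validates, matching the old parser where no argument was required. A short sketch:

```python
from pydantic import BaseModel, Field


class TextToAudioPayload(BaseModel):
    message_id: str | None = Field(default=None)
    voice: str | None = Field(default=None)
    text: str | None = Field(default=None)
    streaming: bool | None = Field(default=None)


# An empty body (service_api_ns.payload is None) validates to all-None fields.
payload = TextToAudioPayload.model_validate({})
assert payload.message_id is None and payload.voice is None

payload = TextToAudioPayload.model_validate({"text": "hello", "voice": "alloy"})
assert payload.text == "hello"
```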
View File
@@ -1,10 +1,14 @@
 import logging
+from typing import Any, Literal
+from uuid import UUID

 from flask import request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

 import services
+from controllers.common.schema import register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import (
     AppUnavailableError,
@@ -26,7 +30,6 @@ from core.errors.error import (
 from core.helper.trace_id_helper import get_external_trace_id
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
-from libs.helper import uuid_value
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
 from services.app_task_service import AppTaskService
@@ -36,40 +39,31 @@ from services.errors.llm import InvokeRateLimitError

 logger = logging.getLogger(__name__)

-# Define parser for completion API
-completion_parser = (
-    reqparse.RequestParser()
-    .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
-    .add_argument("query", type=str, location="json", default="", help="The query string")
-    .add_argument("files", type=list, required=False, location="json", help="List of file attachments")
-    .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
-    .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
-)
-
-# Define parser for chat API
-chat_parser = (
-    reqparse.RequestParser()
-    .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
-    .add_argument("query", type=str, required=True, location="json", help="The chat query")
-    .add_argument("files", type=list, required=False, location="json", help="List of file attachments")
-    .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
-    .add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
-    .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
-    .add_argument(
-        "auto_generate_name",
-        type=bool,
-        required=False,
-        default=True,
-        location="json",
-        help="Auto generate conversation name",
-    )
-    .add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
-)
+class CompletionRequestPayload(BaseModel):
+    inputs: dict[str, Any]
+    query: str = Field(default="")
+    files: list[dict[str, Any]] | None = None
+    response_mode: Literal["blocking", "streaming"] | None = None
+    retriever_from: str = Field(default="dev")
+
+
+class ChatRequestPayload(BaseModel):
+    inputs: dict[str, Any]
+    query: str
+    files: list[dict[str, Any]] | None = None
+    response_mode: Literal["blocking", "streaming"] | None = None
+    conversation_id: UUID | None = None
+    retriever_from: str = Field(default="dev")
+    auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
+    workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
+
+
+register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)


 @service_api_ns.route("/completion-messages")
 class CompletionApi(Resource):
-    @service_api_ns.expect(completion_parser)
+    @service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
     @service_api_ns.doc("create_completion")
     @service_api_ns.doc(description="Create a completion for the given prompt")
     @service_api_ns.doc(
@@ -91,12 +85,13 @@ class CompletionApi(Resource):
         if app_model.mode != AppMode.COMPLETION:
             raise AppUnavailableError()

-        args = completion_parser.parse_args()
+        payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
         external_trace_id = get_external_trace_id(request)
+        args = payload.model_dump(exclude_none=True)
         if external_trace_id:
             args["external_trace_id"] = external_trace_id

-        streaming = args["response_mode"] == "streaming"
+        streaming = payload.response_mode == "streaming"

         args["auto_generate_name"] = False
@@ -162,7 +157,7 @@ class CompletionStopApi(Resource):

 @service_api_ns.route("/chat-messages")
 class ChatApi(Resource):
-    @service_api_ns.expect(chat_parser)
+    @service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
     @service_api_ns.doc("create_chat_message")
     @service_api_ns.doc(description="Send a message in a chat conversation")
     @service_api_ns.doc(
@@ -186,13 +181,14 @@ class ChatApi(Resource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        args = chat_parser.parse_args()
+        payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
         external_trace_id = get_external_trace_id(request)
+        args = payload.model_dump(exclude_none=True)
         if external_trace_id:
             args["external_trace_id"] = external_trace_id

-        streaming = args["response_mode"] == "streaming"
+        streaming = payload.response_mode == "streaming"

         try:
             response = AppGenerateService.generate(
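Two details here are easy to miss: `response_mode` is constrained by a `Literal`, and `model_dump(exclude_none=True)` keeps unset optional fields out of the args dict passed to `AppGenerateService.generate`. A sketch with a minimal payload:

```python
from typing import Any, Literal
from pydantic import BaseModel, Field, ValidationError


class CompletionRequestPayload(BaseModel):
    inputs: dict[str, Any]
    query: str = Field(default="")
    files: list[dict[str, Any]] | None = None
    response_mode: Literal["blocking", "streaming"] | None = None
    retriever_from: str = Field(default="dev")


payload = CompletionRequestPayload.model_validate({"inputs": {}, "response_mode": "streaming"})
args = payload.model_dump(exclude_none=True)
# files is None, so it is absent; defaults for query/retriever_from are kept.
assert args == {"inputs": {}, "query": "", "response_mode": "streaming", "retriever_from": "dev"}

# choices=["blocking", "streaming"] from reqparse becomes a Literal check.
try:
    CompletionRequestPayload.model_validate({"inputs": {}, "response_mode": "async"})
except ValidationError:
    pass
```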
View File
@@ -1,10 +1,15 @@
-from flask_restx import Resource, reqparse
+from typing import Any, Literal
+from uuid import UUID
+
+from flask import request
+from flask_restx import Resource
 from flask_restx._http import HTTPStatus
-from flask_restx.inputs import int_range
+from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, NotFound

 import services
+from controllers.common.schema import register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
@@ -19,74 +24,44 @@ from fields.conversation_variable_fields import (
     build_conversation_variable_infinite_scroll_pagination_model,
     build_conversation_variable_model,
 )
-from libs.helper import uuid_value
 from models.model import App, AppMode, EndUser
 from services.conversation_service import ConversationService

-# Define parsers for conversation APIs
-conversation_list_parser = (
-    reqparse.RequestParser()
-    .add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
-    .add_argument(
-        "limit",
-        type=int_range(1, 100),
-        required=False,
-        default=20,
-        location="args",
-        help="Number of conversations to return",
-    )
-    .add_argument(
-        "sort_by",
-        type=str,
-        choices=["created_at", "-created_at", "updated_at", "-updated_at"],
-        required=False,
-        default="-updated_at",
-        location="args",
-        help="Sort order for conversations",
-    )
-)
-
-conversation_rename_parser = (
-    reqparse.RequestParser()
-    .add_argument("name", type=str, required=False, location="json", help="New conversation name")
-    .add_argument(
-        "auto_generate",
-        type=bool,
-        required=False,
-        default=False,
-        location="json",
-        help="Auto-generate conversation name",
-    )
-)
-
-conversation_variables_parser = (
-    reqparse.RequestParser()
-    .add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
-    .add_argument(
-        "limit",
-        type=int_range(1, 100),
-        required=False,
-        default=20,
-        location="args",
-        help="Number of variables to return",
-    )
-)
-
-conversation_variable_update_parser = reqparse.RequestParser().add_argument(
-    # using lambda is for passing the already-typed value without modification
-    # if no lambda, it will be converted to string
-    # the string cannot be converted using json.loads
-    "value",
-    required=True,
-    location="json",
-    type=lambda x: x,
-    help="New value for the conversation variable",
+class ConversationListQuery(BaseModel):
+    last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
+    limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
+    sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
+        default="-updated_at", description="Sort order for conversations"
+    )
+
+
+class ConversationRenamePayload(BaseModel):
+    name: str = Field(description="New conversation name")
+    auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
+
+
+class ConversationVariablesQuery(BaseModel):
+    last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
+    limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
+
+
+class ConversationVariableUpdatePayload(BaseModel):
+    value: Any
+
+
+register_schema_models(
+    service_api_ns,
+    ConversationListQuery,
+    ConversationRenamePayload,
+    ConversationVariablesQuery,
+    ConversationVariableUpdatePayload,
 )


 @service_api_ns.route("/conversations")
 class ConversationApi(Resource):
-    @service_api_ns.expect(conversation_list_parser)
+    @service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__])
     @service_api_ns.doc("list_conversations")
     @service_api_ns.doc(description="List all conversations for the current user")
     @service_api_ns.doc(
@@ -107,7 +82,8 @@ class ConversationApi(Resource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        args = conversation_list_parser.parse_args()
+        query_args = ConversationListQuery.model_validate(request.args.to_dict())
+        last_id = str(query_args.last_id) if query_args.last_id else None

         try:
             with Session(db.engine) as session:
@@ -115,10 +91,10 @@ class ConversationApi(Resource):
                     session=session,
                     app_model=app_model,
                     user=end_user,
-                    last_id=args["last_id"],
-                    limit=args["limit"],
+                    last_id=last_id,
+                    limit=query_args.limit,
                     invoke_from=InvokeFrom.SERVICE_API,
-                    sort_by=args["sort_by"],
+                    sort_by=query_args.sort_by,
                 )
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -155,7 +131,7 @@ class ConversationDetailApi(Resource):

 @service_api_ns.route("/conversations/<uuid:c_id>/name")
 class ConversationRenameApi(Resource):
-    @service_api_ns.expect(conversation_rename_parser)
+    @service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
     @service_api_ns.doc("rename_conversation")
     @service_api_ns.doc(description="Rename a conversation or auto-generate a name")
     @service_api_ns.doc(params={"c_id": "Conversation ID"})
@@ -176,17 +152,17 @@ class ConversationRenameApi(Resource):
         conversation_id = str(c_id)

-        args = conversation_rename_parser.parse_args()
+        payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})

         try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
+            return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")


 @service_api_ns.route("/conversations/<uuid:c_id>/variables")
 class ConversationVariablesApi(Resource):
-    @service_api_ns.expect(conversation_variables_parser)
+    @service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__])
     @service_api_ns.doc("list_conversation_variables")
     @service_api_ns.doc(description="List all variables for a conversation")
     @service_api_ns.doc(params={"c_id": "Conversation ID"})
@@ -211,11 +187,12 @@ class ConversationVariablesApi(Resource):
         conversation_id = str(c_id)

-        args = conversation_variables_parser.parse_args()
+        query_args = ConversationVariablesQuery.model_validate(request.args.to_dict())
+        last_id = str(query_args.last_id) if query_args.last_id else None

         try:
             return ConversationService.get_conversational_variable(
-                app_model, conversation_id, end_user, args["limit"], args["last_id"]
+                app_model, conversation_id, end_user, query_args.limit, last_id
             )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
@@ -223,7 +200,7 @@ class ConversationVariablesApi(Resource):

 @service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
 class ConversationVariableDetailApi(Resource):
-    @service_api_ns.expect(conversation_variable_update_parser)
+    @service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
     @service_api_ns.doc("update_conversation_variable")
     @service_api_ns.doc(description="Update a conversation variable's value")
     @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
@@ -250,11 +227,11 @@ class ConversationVariableDetailApi(Resource):
         conversation_id = str(c_id)
         variable_id = str(variable_id)

-        args = conversation_variable_update_parser.parse_args()
+        payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})

         try:
             return ConversationService.update_conversation_variable(
-                app_model, conversation_id, variable_id, end_user, args["value"]
+                app_model, conversation_id, variable_id, end_user, payload.value
             )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
View File
@@ -1,9 +1,11 @@
 import logging
 from urllib.parse import quote

-from flask import Response
-from flask_restx import Resource, reqparse
+from flask import Response, request
+from flask_restx import Resource
+from pydantic import BaseModel, Field

+from controllers.common.schema import register_schema_model
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import (
     FileAccessDeniedError,
@@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile

 logger = logging.getLogger(__name__)

-# Define parser for file preview API
-file_preview_parser = reqparse.RequestParser().add_argument(
-    "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
-)
+class FilePreviewQuery(BaseModel):
+    as_attachment: bool = Field(default=False, description="Download as attachment")
+
+
+register_schema_model(service_api_ns, FilePreviewQuery)


 @service_api_ns.route("/files/<uuid:file_id>/preview")
@@ -32,7 +35,7 @@ class FilePreviewApi(Resource):
     Files can only be accessed if they belong to messages within the requesting app's context.
     """

-    @service_api_ns.expect(file_preview_parser)
+    @service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__])
     @service_api_ns.doc("preview_file")
     @service_api_ns.doc(description="Preview or download a file uploaded via Service API")
     @service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
@@ -55,7 +58,7 @@ class FilePreviewApi(Resource):
         file_id = str(file_id)

         # Parse query parameters
-        args = file_preview_parser.parse_args()
+        args = FilePreviewQuery.model_validate(request.args.to_dict())

         # Validate file ownership and get file objects
         _, upload_file = self._validate_file_ownership(file_id, app_model.id)
@@ -67,7 +70,7 @@ class FilePreviewApi(Resource):
             raise FileNotFoundError(f"Failed to load file content: {str(e)}")

         # Build response with appropriate headers
-        response = self._build_file_response(generator, upload_file, args["as_attachment"])
+        response = self._build_file_response(generator, upload_file, args.as_attachment)

         return response
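`as_attachment` comes off the query string as text, and reqparse's `type=bool` treated any non-empty string as truthy, so `?as_attachment=false` used to parse as `True`. Pydantic parses the usual spellings instead; a sketch:

```python
from pydantic import BaseModel, Field


class FilePreviewQuery(BaseModel):
    as_attachment: bool = Field(default=False)


# bool("false") is True in plain Python; pydantic's bool parsing understands
# "true"/"false", "1"/"0", and similar spellings (pydantic v2 assumed).
assert FilePreviewQuery.model_validate({"as_attachment": "false"}).as_attachment is False
assert FilePreviewQuery.model_validate({"as_attachment": "true"}).as_attachment is True
assert FilePreviewQuery.model_validate({}).as_attachment is False
```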
View File
@@ -1,11 +1,15 @@
 import json
 import logging
+from typing import Literal
+from uuid import UUID

-from flask_restx import Api, Namespace, Resource, fields, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import Namespace, Resource, fields
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

 import services
+from controllers.common.schema import register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
@@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from fields.conversation_fields import build_message_file_model
 from fields.message_fields import build_agent_thought_model, build_feedback_model
 from fields.raws import FilesContainedField
-from libs.helper import TimestampField, uuid_value
+from libs.helper import TimestampField
 from models.model import App, AppMode, EndUser
 from services.errors.message import (
     FirstMessageNotExistsError,
@@ -25,42 +29,26 @@ from services.message_service import MessageService

 logger = logging.getLogger(__name__)

-# Define parsers for message APIs
-message_list_parser = (
-    reqparse.RequestParser()
-    .add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
-    .add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
-    .add_argument(
-        "limit",
-        type=int_range(1, 100),
-        required=False,
-        default=20,
-        location="args",
-        help="Number of messages to return",
-    )
-)
-
-message_feedback_parser = (
-    reqparse.RequestParser()
-    .add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
-    .add_argument("content", type=str, location="json", help="Feedback content")
-)
-
-feedback_list_parser = (
-    reqparse.RequestParser()
-    .add_argument("page", type=int, default=1, location="args", help="Page number")
-    .add_argument(
-        "limit",
-        type=int_range(1, 101),
-        required=False,
-        default=20,
-        location="args",
-        help="Number of feedbacks per page",
-    )
-)
+class MessageListQuery(BaseModel):
+    conversation_id: UUID
+    first_id: UUID | None = None
+    limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
+
+
+class MessageFeedbackPayload(BaseModel):
+    rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
+    content: str | None = Field(default=None, description="Feedback content")
+
+
+class FeedbackListQuery(BaseModel):
+    page: int = Field(default=1, ge=1, description="Page number")
+    limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
+
+
+register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)


-def build_message_model(api_or_ns: Api | Namespace):
+def build_message_model(api_or_ns: Namespace):
     """Build the message model for the API or Namespace."""
     # First build the nested models
     feedback_model = build_feedback_model(api_or_ns)
@@ -90,7 +78,7 @@ def build_message_model(api_or_ns: Api | Namespace):
     return api_or_ns.model("Message", message_fields)


-def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
+def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
     """Build the message infinite scroll pagination model for the API or Namespace."""
     # Build the nested message model first
     message_model = build_message_model(api_or_ns)
@@ -105,7 +93,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):

 @service_api_ns.route("/messages")
 class MessageListApi(Resource):
-    @service_api_ns.expect(message_list_parser)
+    @service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
     @service_api_ns.doc("list_messages")
     @service_api_ns.doc(description="List messages in a conversation")
     @service_api_ns.doc(
@@ -126,11 +114,13 @@ class MessageListApi(Resource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        args = message_list_parser.parse_args()
+        query_args = MessageListQuery.model_validate(request.args.to_dict())
+        conversation_id = str(query_args.conversation_id)
+        first_id = str(query_args.first_id) if query_args.first_id else None

         try:
             return MessageService.pagination_by_first_id(
-                app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
+                app_model, end_user, conversation_id, first_id, query_args.limit
             )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
@@ -140,7 +130,7 @@ class MessageListApi(Resource):

 @service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
 class MessageFeedbackApi(Resource):
-    @service_api_ns.expect(message_feedback_parser)
+    @service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
     @service_api_ns.doc("create_message_feedback")
     @service_api_ns.doc(description="Submit feedback for a message")
     @service_api_ns.doc(params={"message_id": "Message ID"})
@@ -159,15 +149,15 @@ class MessageFeedbackApi(Resource):
         """
         message_id = str(message_id)

-        args = message_feedback_parser.parse_args()
+        payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})

         try:
             MessageService.create_feedback(
                 app_model=app_model,
                 message_id=message_id,
                 user=end_user,
-                rating=args.get("rating"),
-                content=args.get("content"),
+                rating=payload.rating,
+                content=payload.content,
             )
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -177,7 +167,7 @@ class MessageFeedbackApi(Resource):

 @service_api_ns.route("/app/feedbacks")
 class AppGetFeedbacksApi(Resource):
-    @service_api_ns.expect(feedback_list_parser)
+    @service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__])
     @service_api_ns.doc("get_app_feedbacks")
     @service_api_ns.doc(description="Get all feedbacks for the application")
     @service_api_ns.doc(
@@ -192,8 +182,8 @@ class AppGetFeedbacksApi(Resource):

         Returns paginated list of all feedback submitted for messages in this app.
         """
-        args = feedback_list_parser.parse_args()
-        feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
+        query_args = FeedbackListQuery.model_validate(request.args.to_dict())
+        feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit)
         return {"data": feedbacks}
View File
@@ -1,12 +1,14 @@
 import logging
+from typing import Any, Literal

 from dateutil.parser import isoparse
 from flask import request
-from flask_restx import Api, Namespace, Resource, fields, reqparse
-from flask_restx.inputs import int_range
+from flask_restx import Api, Namespace, Resource, fields
+from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session, sessionmaker
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

+from controllers.common.schema import register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import (
     CompletionRequestError,
@@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService

 logger = logging.getLogger(__name__)

-# Define parsers for workflow APIs
-workflow_run_parser = (
-    reqparse.RequestParser()
-    .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-    .add_argument("files", type=list, required=False, location="json")
-    .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
-)
-
-workflow_log_parser = (
-    reqparse.RequestParser()
-    .add_argument("keyword", type=str, location="args")
-    .add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
-    .add_argument("created_at__before", type=str, location="args")
-    .add_argument("created_at__after", type=str, location="args")
-    .add_argument(
-        "created_by_end_user_session_id",
-        type=str,
-        location="args",
-        required=False,
-        default=None,
-    )
-    .add_argument(
-        "created_by_account",
-        type=str,
-        location="args",
-        required=False,
-        default=None,
-    )
-    .add_argument("page", type=int_range(1, 99999), default=1, location="args")
-    .add_argument("limit", type=int_range(1, 100), default=20, location="args")
-)
+class WorkflowRunPayload(BaseModel):
+    inputs: dict[str, Any]
+    files: list[dict[str, Any]] | None = None
+    response_mode: Literal["blocking", "streaming"] | None = None
+
+
+class WorkflowLogQuery(BaseModel):
+    keyword: str | None = None
+    status: Literal["succeeded", "failed", "stopped"] | None = None
+    created_at__before: str | None = None
+    created_at__after: str | None = None
+    created_by_end_user_session_id: str | None = None
+    created_by_account: str | None = None
+    page: int = Field(default=1, ge=1, le=99999)
+    limit: int = Field(default=20, ge=1, le=100)
+
+
+register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)

 workflow_run_fields = {
     "id": fields.String,
@@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource):

 @service_api_ns.route("/workflows/run")
 class WorkflowRunApi(Resource):
-    @service_api_ns.expect(workflow_run_parser)
+    @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
     @service_api_ns.doc("run_workflow")
     @service_api_ns.doc(description="Execute a workflow")
     @service_api_ns.doc(
@@ -154,11 +144,12 @@ class WorkflowRunApi(Resource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()

-        args = workflow_run_parser.parse_args()
+        payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
         external_trace_id = get_external_trace_id(request)
         if external_trace_id:
             args["external_trace_id"] = external_trace_id
-        streaming = args.get("response_mode") == "streaming"
+        streaming = payload.response_mode == "streaming"

         try:
             response = AppGenerateService.generate(
@@ -185,7 +176,7 @@ class WorkflowRunApi(Resource):

 @service_api_ns.route("/workflows/<string:workflow_id>/run")
 class WorkflowRunByIdApi(Resource):
-    @service_api_ns.expect(workflow_run_parser)
+    @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
     @service_api_ns.doc("run_workflow_by_id")
     @service_api_ns.doc(description="Execute a specific workflow by ID")
     @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
@@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()

-        args = workflow_run_parser.parse_args()
+        payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)

         # Add workflow_id to args for AppGenerateService
         args["workflow_id"] = workflow_id
@@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource):
         external_trace_id = get_external_trace_id(request)
         if external_trace_id:
             args["external_trace_id"] = external_trace_id
-        streaming = args.get("response_mode") == "streaming"
+        streaming = payload.response_mode == "streaming"

         try:
             response = AppGenerateService.generate(
@@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource):

 @service_api_ns.route("/workflows/logs")
 class WorkflowAppLogApi(Resource):
-    @service_api_ns.expect(workflow_log_parser)
+    @service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__])
     @service_api_ns.doc("get_workflow_logs")
     @service_api_ns.doc(description="Get workflow execution logs")
     @service_api_ns.doc(
@@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource):

         Returns paginated workflow execution logs with filtering options.
         """
-        args = workflow_log_parser.parse_args()
+        args = WorkflowLogQuery.model_validate(request.args.to_dict())

-        args.status = WorkflowExecutionStatus(args.status) if args.status else None
-        if args.created_at__before:
-            args.created_at__before = isoparse(args.created_at__before)
-
-        if args.created_at__after:
-            args.created_at__after = isoparse(args.created_at__after)
+        status = WorkflowExecutionStatus(args.status) if args.status else None
+        created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
+        created_at_after = isoparse(args.created_at__after) if args.created_at__after else None

         # get paginate workflow app logs
         workflow_app_service = WorkflowAppService()
@@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource):
                 session=session,
                 app_model=app_model,
                 keyword=args.keyword,
-                status=args.status,
-                created_at_before=args.created_at__before,
-                created_at_after=args.created_at__after,
+                status=status,
+                created_at_before=created_at_before,
+                created_at_after=created_at_after,
                 page=args.page,
                 limit=args.limit,
                 created_by_end_user_session_id=args.created_by_end_user_session_id,
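One call-site nuance: the old code wrote parsed datetimes back onto the args namespace, while the new code keeps the model's `str` fields as declared and derives locals. A sketch of the date handling (`isoparse` is the same dateutil helper this module already imports; the timestamp is made up):

```python
from dateutil.parser import isoparse
from pydantic import BaseModel, Field


class WorkflowLogQuery(BaseModel):
    keyword: str | None = None
    created_at__before: str | None = None
    created_at__after: str | None = None
    page: int = Field(default=1, ge=1, le=99999)
    limit: int = Field(default=20, ge=1, le=100)


args = WorkflowLogQuery.model_validate({"created_at__after": "2025-01-01T00:00:00+00:00", "page": "2"})

# Derive parsed values locally instead of stuffing datetimes into str fields.
created_at_after = isoparse(args.created_at__after) if args.created_at__after else None
created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
assert created_at_after is not None and created_at_before is None
assert args.page == 2  # "2" coerced and range-checked
```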
View File
@@ -1,10 +1,12 @@
from typing import Any, Literal, cast from typing import Any, Literal, cast
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService from services.tag_service import TagService
def _validate_name(name): class DatasetCreatePayload(BaseModel):
if not name or len(name) < 1 or len(name) > 40: name: str = Field(..., min_length=1, max_length=40)
raise ValueError("Name must be between 1 to 40 characters.") description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400)
return name indexing_technique: Literal["high_quality", "economy"] | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
external_knowledge_api_id: str | None = None
provider: str = "vendor"
external_knowledge_id: str | None = None
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
# Define parsers for dataset operations class DatasetUpdatePayload(BaseModel):
dataset_create_parser = ( name: str | None = Field(default=None, min_length=1, max_length=40)
reqparse.RequestParser() description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400)
.add_argument( indexing_technique: Literal["high_quality", "economy"] | None = None
"name", permission: DatasetPermissionEnum | None = None
nullable=False, embedding_model: str | None = None
required=True, embedding_model_provider: str | None = None
help="type is required. Name must be between 1 to 40 characters.", retrieval_model: RetrievalModel | None = None
type=_validate_name, partial_member_list: list[str] | None = None
) external_retrieval_model: dict[str, Any] | None = None
.add_argument( external_knowledge_id: str | None = None
"description", external_knowledge_api_id: str | None = None
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
required=False,
nullable=False,
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
default="_validate_name",
)
.add_argument(
"provider",
type=str,
nullable=True,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
dataset_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
)
tag_create_parser = reqparse.RequestParser().add_argument( class TagNamePayload(BaseModel):
"name", name: str = Field(..., min_length=1, max_length=50)
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
tag_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
)
tag_delete_parser = reqparse.RequestParser().add_argument( class TagCreatePayload(TagNamePayload):
"tag_id", nullable=False, required=True, help="Id of a tag.", type=str pass
)
tag_binding_parser = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)
)
tag_unbinding_parser = ( class TagUpdatePayload(TagNamePayload):
reqparse.RequestParser() tag_id: str
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
class TagDeletePayload(BaseModel):
tag_id: str
class TagBindingPayload(BaseModel):
tag_ids: list[str]
target_id: str
@field_validator("tag_ids")
@classmethod
def validate_tag_ids(cls, value: list[str]) -> list[str]:
if not value:
raise ValueError("Tag IDs is required.")
return value
class TagUnbindingPayload(BaseModel):
tag_id: str
target_id: str
register_schema_models(
service_api_ns,
DatasetCreatePayload,
DatasetUpdatePayload,
TagCreatePayload,
TagUpdatePayload,
TagDeletePayload,
TagBindingPayload,
TagUnbindingPayload,
) )
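For reviewers new to this pattern, here is a minimal, self-contained sketch of the register-then-expect flow the payload models above rely on. FooPayload, FooApi, and the /foo route are hypothetical names; the registration call mirrors register_schema_models and the handler mirrors the model_validate(service_api_ns.payload or {}) idiom used throughout this commit.

from flask_restx import Namespace, Resource
from pydantic import BaseModel, Field

ns = Namespace("demo")

class FooPayload(BaseModel):
    name: str = Field(..., min_length=1, max_length=40)

# Register the Pydantic JSON schema under the class name so @ns.expect can
# reference it for Swagger docs (Swagger 2.0 uses #/definitions/{model} refs).
ns.schema_model(FooPayload.__name__, FooPayload.model_json_schema(ref_template="#/definitions/{model}"))

@ns.route("/foo")
class FooApi(Resource):
    @ns.expect(ns.models[FooPayload.__name__])
    def post(self):
        # ns.payload is the parsed JSON body; model_validate raises
        # pydantic.ValidationError on malformed input.
        payload = FooPayload.model_validate(ns.payload or {})
        return {"name": payload.name}, 200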
@@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200 return response, 200
@service_api_ns.expect(dataset_create_parser) @service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset") @service_api_ns.doc("create_dataset")
@service_api_ns.doc(description="Create a new dataset") @service_api_ns.doc(description="Create a new dataset")
@service_api_ns.doc( @service_api_ns.doc(
@@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id): def post(self, tenant_id):
"""Resource for creating datasets.""" """Resource for creating datasets."""
args = dataset_create_parser.parse_args() payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {})
embedding_model_provider = args.get("embedding_model_provider") embedding_model_provider = payload.embedding_model_provider
embedding_model = args.get("embedding_model") embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model: if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model") retrieval_model = payload.retrieval_model
if ( if (
retrieval_model retrieval_model
and retrieval_model.get("reranking_model") and retrieval_model.reranking_model
and retrieval_model.get("reranking_model").get("reranking_provider_name") and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"), retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.get("reranking_model").get("reranking_model_name"), retrieval_model.reranking_model.reranking_model_name,
) )
try: try:
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
name=args["name"], name=payload.name,
description=args["description"], description=payload.description,
indexing_technique=args["indexing_technique"], indexing_technique=payload.indexing_technique,
account=current_user, account=current_user,
permission=args["permission"], permission=str(payload.permission) if payload.permission else None,
provider=args["provider"], provider=payload.provider,
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=args["external_knowledge_id"], external_knowledge_id=payload.external_knowledge_id,
embedding_model_provider=args["embedding_model_provider"], embedding_model_provider=payload.embedding_model_provider,
embedding_model_name=args["embedding_model"], embedding_model_name=payload.embedding_model,
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"]) retrieval_model=payload.retrieval_model,
if args["retrieval_model"] is not None
else None,
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
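A note on the reranking guard above: besides swapping repeated .get() chains for typed attribute access, the new condition also checks reranking_model_name, which the old parser-based create path did not. A sketch with stand-in models (RerankingSketch and RetrievalSketch are hypothetical names for the entities in knowledge_entities):

from pydantic import BaseModel

class RerankingSketch(BaseModel):
    reranking_provider_name: str | None = None
    reranking_model_name: str | None = None

class RetrievalSketch(BaseModel):
    reranking_model: RerankingSketch | None = None

rm = RetrievalSketch.model_validate({"reranking_model": {"reranking_provider_name": "cohere"}})
# Each attribute is typed, so the truthiness chain never needs .get()
if rm.reranking_model and rm.reranking_model.reranking_provider_name and rm.reranking_model.reranking_model_name:
    print("both reranking fields present")
else:
    print("incomplete reranking config")  # this branch runs: model name missing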
@@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource):
return data, 200 return data, 200
@service_api_ns.expect(dataset_update_parser) @service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset") @service_api_ns.doc("update_dataset")
@service_api_ns.doc(description="Update an existing dataset") @service_api_ns.doc(description="Update an existing dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
args = dataset_update_parser.parse_args() payload_dict = service_api_ns.payload or {}
data = request.get_json() payload = DatasetUpdatePayload.model_validate(payload_dict)
update_data = payload.model_dump(exclude_unset=True)
if payload.permission is not None:
update_data["permission"] = str(payload.permission)
if payload.retrieval_model is not None:
update_data["retrieval_model"] = payload.retrieval_model.model_dump()
# check embedding model setting # check embedding model setting
embedding_model_provider = data.get("embedding_model_provider") embedding_model_provider = payload.embedding_model_provider
embedding_model = data.get("embedding_model") embedding_model = payload.embedding_model
if data.get("indexing_technique") == "high_quality" or embedding_model_provider: if payload.indexing_technique == "high_quality" or embedding_model_provider:
if embedding_model_provider and embedding_model: if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting( DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model dataset.tenant_id, embedding_model_provider, embedding_model
) )
retrieval_model = data.get("retrieval_model") retrieval_model = payload.retrieval_model
if ( if (
retrieval_model retrieval_model
and retrieval_model.get("reranking_model") and retrieval_model.reranking_model
and retrieval_model.get("reranking_model").get("reranking_provider_name") and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
dataset.tenant_id, dataset.tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"), retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.get("reranking_model").get("reranking_model_name"), retrieval_model.reranking_model.reranking_model_name,
) )
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission( DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list") current_user,
dataset,
str(payload.permission) if payload.permission else None,
payload.partial_member_list,
) )
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members # clear partial member list when permission is only_me or all_team_members
elif ( elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str) DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
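The update path above dumps the payload with exclude_unset=True so that DatasetService.update_dataset only sees fields the caller actually sent, keeping semantics close to the old store_missing=False behavior. The distinction from exclude_none matters for explicit nulls; a small sketch with a hypothetical model:

from pydantic import BaseModel

class UpdateSketch(BaseModel):
    name: str | None = None
    description: str | None = None

p = UpdateSketch.model_validate({"name": "docs", "description": None})
# exclude_unset keeps the explicit null, exclude_none drops it
print(p.model_dump(exclude_unset=True))  # {'name': 'docs', 'description': None}
print(p.model_dump(exclude_none=True))   # {'name': 'docs'}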
@@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource):
return tags, 200 return tags, 200
@service_api_ns.expect(tag_create_parser) @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag") @service_api_ns.doc("create_dataset_tag")
@service_api_ns.doc(description="Add a knowledge type tag") @service_api_ns.doc(description="Add a knowledge type tag")
@service_api_ns.doc( @service_api_ns.doc(
@@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_create_parser.parse_args() payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
args["type"] = "knowledge" tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200 return response, 200
@service_api_ns.expect(tag_update_parser) @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_tag") @service_api_ns.doc("update_dataset_tag")
@service_api_ns.doc(description="Update a knowledge type tag") @service_api_ns.doc(description="Update a knowledge type tag")
@service_api_ns.doc( @service_api_ns.doc(
@@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_update_parser.parse_args() payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
args["type"] = "knowledge" params = {"name": payload.name, "type": "knowledge"}
tag_id = args["tag_id"] tag_id = payload.tag_id
tag = TagService.update_tags(args, tag_id) tag = TagService.update_tags(params, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
@@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource):
return response, 200 return response, 200
@service_api_ns.expect(tag_delete_parser) @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@service_api_ns.doc("delete_dataset_tag") @service_api_ns.doc("delete_dataset_tag")
@service_api_ns.doc(description="Delete a knowledge type tag") @service_api_ns.doc(description="Delete a knowledge type tag")
@service_api_ns.doc( @service_api_ns.doc(
@@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource):
@edit_permission_required @edit_permission_required
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
"""Delete a knowledge type tag.""" """Delete a knowledge type tag."""
args = tag_delete_parser.parse_args() payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(args["tag_id"]) TagService.delete_tag(payload.tag_id)
return 204 return 204
@service_api_ns.route("/datasets/tags/binding") @service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource): class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.expect(tag_binding_parser) @service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
@service_api_ns.doc("bind_dataset_tags") @service_api_ns.doc("bind_dataset_tags")
@service_api_ns.doc(description="Bind tags to a dataset") @service_api_ns.doc(description="Bind tags to a dataset")
@service_api_ns.doc( @service_api_ns.doc(
@@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_binding_parser.parse_args() payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
args["type"] = "knowledge" TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
TagService.save_tag_binding(args)
return 204 return 204
@service_api_ns.route("/datasets/tags/unbinding") @service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource): class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(tag_unbinding_parser) @service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag") @service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset") @service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc( @service_api_ns.doc(
@@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_unbinding_parser.parse_args() payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
args["type"] = "knowledge" TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
TagService.delete_tag_binding(args)
return 204 return 204
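The tag payloads above share their 1-to-50 character name rule through the TagNamePayload base class, replacing two copies of the throwaway-generator lambda validator. A quick check of the inherited constraint (class names reused from the diff, the values are illustrative):

from pydantic import BaseModel, Field, ValidationError

class TagNameSketch(BaseModel):
    name: str = Field(..., min_length=1, max_length=50)

class TagUpdateSketch(TagNameSketch):
    tag_id: str

try:
    TagUpdateSketch.model_validate({"name": "", "tag_id": "t-1"})
except ValidationError as exc:
    # min_length=1 rejects the empty name that the old lambda guarded against
    print(exc.error_count())  # 1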
@@ -3,8 +3,8 @@ from typing import Self
from uuid import UUID from uuid import UUID
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel, model_validator from pydantic import BaseModel, Field, model_validator
from sqlalchemy import desc, select from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService from services.file_service import FileService
# Define parsers for document operations
document_text_create_parser = ( class DocumentTextCreatePayload(BaseModel):
reqparse.RequestParser() name: str
.add_argument("name", type=str, required=True, nullable=False, location="json") text: str
.add_argument("text", type=str, required=True, nullable=False, location="json") process_rule: ProcessRule | None = None
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") original_document_id: str | None = None
.add_argument("original_document_id", type=str, required=False, location="json") doc_form: str = Field(default="text_model")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") doc_language: str = Field(default="English")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") indexing_technique: str | None = None
.add_argument( retrieval_model: RetrievalModel | None = None
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" embedding_model: str | None = None
) embedding_model_provider: str | None = None
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel):
return self return self
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]: for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
class DocumentAddByTextApi(DatasetApiResource): class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents.""" """Resource for documents."""
@service_api_ns.expect(document_text_create_parser) @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text") @service_api_ns.doc("create_document_by_text")
@service_api_ns.doc(description="Create a new document by providing text content") @service_api_ns.doc(description="Create a new document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create document by text.""" """Create document by text."""
args = document_text_create_parser.parse_args() payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
@@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource):
if not dataset.indexing_technique and not args["indexing_technique"]: if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
text = args.get("text") embedding_model_provider = payload.embedding_model_provider
name = args.get("name") embedding_model = payload.embedding_model
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
if embedding_model_provider and embedding_model: if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model") retrieval_model = payload.retrieval_model
if ( if (
retrieval_model retrieval_model
and retrieval_model.get("reranking_model") and retrieval_model.reranking_model
and retrieval_model.get("reranking_model").get("reranking_provider_name") and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"), retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.get("reranking_model").get("reranking_model_name"), retrieval_model.reranking_model.reranking_model_name,
) )
if not current_user: if not current_user:
raise ValueError("current_user is required") raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text( upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
) )
data_source = { data_source = {
"type": "upload_file", "type": "upload_file",
@@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents.""" """Resource for update documents."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True) @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text") @service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content") @service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -189,22 +183,23 @@ class DocumentUpdateByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text.""" """Update document by text."""
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True) payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
args = payload.model_dump(exclude_none=True)
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
retrieval_model = args.get("retrieval_model") retrieval_model = payload.retrieval_model
if ( if (
retrieval_model retrieval_model
and retrieval_model.get("reranking_model") and retrieval_model.reranking_model
and retrieval_model.get("reranking_model").get("reranking_provider_name") and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"), retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.get("reranking_model").get("reranking_model_name"), retrieval_model.reranking_model.reranking_model_name,
) )
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
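The document-by-text handlers above keep two views of the same request: the typed payload for nested checks (payload.retrieval_model.reranking_model and so on) and a plain dict from model_dump(exclude_none=True) for service code that still expects args-style dicts. One caveat worth noting, shown with a hypothetical model: exclude_none drops unset optional keys entirely, so dict consumers should prefer args.get(...) over args[...].

from pydantic import BaseModel

class DocSketch(BaseModel):
    name: str
    doc_form: str = "text_model"
    process_rule: dict | None = None

payload = DocSketch.model_validate({"name": "intro"})
args = payload.model_dump(exclude_none=True)
print(args)                      # {'name': 'intro', 'doc_form': 'text_model'}
print(args.get("process_rule"))  # None; args["process_rule"] would raise KeyError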
@@ -1,9 +1,11 @@
from typing import Literal from typing import Literal
from flask_login import current_user from flask_login import current_user
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
@@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import (
) )
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
# Define parsers for metadata APIs
metadata_create_parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type")
.add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name")
)
metadata_update_parser = reqparse.RequestParser().add_argument( class MetadataUpdatePayload(BaseModel):
"name", type=str, required=True, nullable=False, location="json", help="New metadata name" name: str
)
document_metadata_parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" register_schema_model(service_api_ns, MetadataUpdatePayload)
) register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata") @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_create_parser) @service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
@service_api_ns.doc("create_dataset_metadata") @service_api_ns.doc("create_dataset_metadata")
@service_api_ns.doc(description="Create metadata for a dataset") @service_api_ns.doc(description="Create metadata for a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset.""" """Create metadata for a dataset."""
args = metadata_create_parser.parse_args() metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>") @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataServiceApi(DatasetApiResource): class DatasetMetadataServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_update_parser) @service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_metadata") @service_api_ns.doc("update_dataset_metadata")
@service_api_ns.doc(description="Update metadata name") @service_api_ns.doc(description="Update metadata name")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
@@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id): def patch(self, tenant_id, dataset_id, metadata_id):
"""Update metadata name.""" """Update metadata name."""
args = metadata_update_parser.parse_args() payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"]) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return marshal(metadata, dataset_metadata_fields), 200 return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata") @service_api_ns.doc("delete_dataset_metadata")
@@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata") @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditServiceApi(DatasetApiResource): class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.expect(document_metadata_parser) @service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
@service_api_ns.doc("update_documents_metadata") @service_api_ns.doc("update_documents_metadata")
@service_api_ns.doc(description="Update metadata for multiple documents") @service_api_ns.doc(description="Update metadata for multiple documents")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args() metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)
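The metadata endpoints above validate straight into the service-layer entities (MetadataArgs, MetadataOperationData) instead of round-tripping through parse_args(). One practical gain: the old document_metadata_parser only verified that operation_data was a list, while a nested model validates each element. A sketch under assumed field names (document_id and metadata_list are illustrative, not confirmed from the entity definitions):

from pydantic import BaseModel

class OperationItemSketch(BaseModel):
    document_id: str
    metadata_list: list[dict]

class MetadataOperationSketch(BaseModel):
    operation_data: list[OperationItemSketch]

# Every element's shape is checked, not just the outer list type
op = MetadataOperationSketch.model_validate({"operation_data": [{"document_id": "d-1", "metadata_list": []}]})
print(len(op.operation_data))  # 1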
@@ -4,12 +4,12 @@ from collections.abc import Generator
from typing import Any from typing import Any
from flask import request from flask import request
from flask_restx import reqparse from pydantic import BaseModel
from flask_restx.reqparse import ParseResult, RequestParser
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource
@@ -22,11 +22,25 @@ from models.dataset import Pipeline
from models.engine import db from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity from services.rag_pipeline.entity.pipeline_service_api_entities import (
DatasourceNodeRunApiEntity,
PipelineRunApiEntity,
)
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
is_published: bool
register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity)
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins") @service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource): class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins.""" """Resource for datasource plugins."""
@@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str): def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins.""" """Resource for getting datasource plugins."""
# Get query parameter to determine published or draft payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
parser: RequestParser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
.add_argument("is_published", type=bool, required=True, location="json")
)
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService() rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
{
**payload.model_dump(exclude_none=True),
"pipeline_id": str(pipeline.id),
"node_id": node_id,
}
)
return helper.compact_generate_response( return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream( PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node( rag_pipeline_service.run_datasource_workflow_node(
@@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str): def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline.""" """Resource for running a rag pipeline."""
parser: RequestParser = ( payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_published", type=bool, required=True, default=True, location="json")
.add_argument(
"response_mode",
type=str,
required=True,
choices=["streaming", "blocking"],
default="blocking",
location="json",
)
)
args: ParseResult = parser.parse_args()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
@@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource):
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate( response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline, pipeline=pipeline,
user=current_user, user=current_user,
args=args, args=payload.model_dump(),
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
streaming=args.get("response_mode") == "streaming", streaming=payload.response_mode == "streaming",
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
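PipelineRunApi above replaces the reqparse choices=["streaming", "blocking"] argument with a typed entity, and DatasourceNodeRunApi merges the validated body with path context before building its service entity. A condensed sketch of both idioms (RunSketch is hypothetical, with fields mirroring the removed parser):

from typing import Any, Literal
from pydantic import BaseModel

class RunSketch(BaseModel):
    inputs: dict[str, Any]
    is_published: bool = True
    response_mode: Literal["streaming", "blocking"] = "blocking"

payload = RunSketch.model_validate({"inputs": {"q": "hi"}, "response_mode": "streaming"})
streaming = payload.response_mode == "streaming"  # Literal[...] replaces choices=[...]
# Merge path parameters into the dumped payload, as the datasource endpoint does
entity_kwargs = {**payload.model_dump(exclude_none=True), "node_id": "node-1"}
print(streaming, entity_kwargs["node_id"])  # True node-1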
@@ -1,8 +1,12 @@
from typing import Any
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from configs import dify_config from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
@@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
# Define parsers for segment operations
segment_create_parser = reqparse.RequestParser().add_argument(
"segments", type=list, required=False, nullable=True, location="json"
)
segment_list_parser = ( class SegmentCreatePayload(BaseModel):
reqparse.RequestParser() segments: list[dict[str, Any]] | None = None
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("keyword", type=str, default=None, location="args")
)
segment_update_parser = reqparse.RequestParser().add_argument(
"segment", type=dict, required=False, nullable=True, location="json"
)
child_chunk_create_parser = reqparse.RequestParser().add_argument( class SegmentListQuery(BaseModel):
"content", type=str, required=True, nullable=False, location="json" status: list[str] = Field(default_factory=list)
) keyword: str | None = None
child_chunk_list_parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
child_chunk_update_parser = reqparse.RequestParser().add_argument( class SegmentUpdatePayload(BaseModel):
"content", type=str, required=True, nullable=False, location="json" segment: SegmentUpdateArgs
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class ChildChunkUpdatePayload(BaseModel):
content: str
register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,
ChildChunkUpdatePayload,
) )
@@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument(
class SegmentApi(DatasetApiResource): class SegmentApi(DatasetApiResource):
"""Resource for segments.""" """Resource for segments."""
@service_api_ns.expect(segment_create_parser) @service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments") @service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document") @service_api_ns.doc(description="Create segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# validate args # validate args
args = segment_create_parser.parse_args() payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
if args["segments"] is not None: if payload.segments is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit: if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.") raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]: for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document) SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset) segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else: else:
return {"error": "Segments is required"}, 400 return {"error": "Segments is required"}, 400
@service_api_ns.expect(segment_list_parser) @service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
@service_api_ns.doc("list_segments") @service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document") @service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
args = segment_list_parser.parse_args() args = SegmentListQuery.model_validate(
{
"status": request.args.getlist("status"),
"keyword": request.args.get("keyword"),
}
)
segments, total = SegmentService.get_segments( segments, total = SegmentService.get_segments(
document_id=document_id, document_id=document_id,
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
status_list=args["status"], status_list=args.status,
keyword=args["keyword"], keyword=args.keyword,
page=page, page=page,
limit=limit, limit=limit,
) )
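Unlike JSON bodies, the query-string endpoints above build their models from request.args by hand, since service_api_ns.payload only covers the body. Repeated parameters arrive via getlist, replacing reqparse's action="append", and Field(ge=1) replaces manual bounds checks. A sketch blending SegmentListQuery and ChildChunkListQuery:

from pydantic import BaseModel, Field

class ListQuerySketch(BaseModel):
    status: list[str] = Field(default_factory=list)
    keyword: str | None = None
    page: int = Field(default=1, ge=1)

# Values as request.args would supply them: ?status=completed&status=indexing
q = ListQuerySketch.model_validate({"status": ["completed", "indexing"], "keyword": None})
print(q.page, q.status)  # 1 ['completed', 'indexing']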
@@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource):
SegmentService.delete_segment(segment, document, dataset) SegmentService.delete_segment(segment, document, dataset)
return 204 return 204
@service_api_ns.expect(segment_update_parser) @service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment") @service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment") @service_api_ns.doc(description="Update a specific segment")
@service_api_ns.doc( @service_api_ns.doc(
@@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# validate args payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment( updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment") @service_api_ns.doc("get_segment")
@@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource):
class ChildChunkApi(DatasetApiResource): class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks.""" """Resource for child chunks."""
@service_api_ns.expect(child_chunk_create_parser) @service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk") @service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment") @service_api_ns.doc(description="Create a new child chunk for a segment")
@service_api_ns.doc( @service_api_ns.doc(
@@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource):
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# validate args # validate args
args = child_chunk_create_parser.parse_args() payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {})
try: try:
child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset) child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@service_api_ns.expect(child_chunk_list_parser) @service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
@service_api_ns.doc("list_child_chunks") @service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment") @service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc( @service_api_ns.doc(
@@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource):
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
args = child_chunk_list_parser.parse_args() args = ChildChunkListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
page = args["page"] page = args.page
limit = min(args["limit"], 100) limit = min(args.limit, 100)
keyword = args["keyword"] keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
@@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
return 204 return 204
@service_api_ns.expect(child_chunk_update_parser) @service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk") @service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk") @service_api_ns.doc(description="Update a specific child chunk")
@service_api_ns.doc( @service_api_ns.doc(
@@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# validate args # validate args
args = child_chunk_update_parser.parse_args() payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
try: try:
child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset) child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
@@ -101,8 +101,8 @@ class HitTestingService:
dataset: Dataset, dataset: Dataset,
query: str, query: str,
account: Account, account: Account,
external_retrieval_model: dict, external_retrieval_model: dict | None = None,
metadata_filtering_conditions: dict, metadata_filtering_conditions: dict | None = None,
): ):
if dataset.provider != "external": if dataset.provider != "external":
return { return {
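The HitTestingService change above makes the two dict parameters optional with None defaults, so internal callers that have no external retrieval settings can omit them instead of passing placeholders. A trivial sketch of the signature style (retrieve_sketch is hypothetical):

def retrieve_sketch(query: str, external_retrieval_model: dict | None = None) -> dict:
    # A None default keeps call sites clean and normalizes inside the function
    return {"query": query, "settings": external_retrieval_model or {}}

print(retrieve_sketch("hello"))  # {'query': 'hello', 'settings': {}}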
@@ -125,7 +125,7 @@ class TestPartnerTenants:
resource = PartnerTenants() resource = PartnerTenants()
# Act & Assert # Act & Assert
# reqparse will raise BadRequest for missing required field # Validation should raise BadRequest for missing required field
with pytest.raises(BadRequest): with pytest.raises(BadRequest):

@@ -256,24 +256,18 @@ class TestFilePreviewApi:
mock_app, # App query for tenant validation mock_app, # App query for tenant validation
] ]
with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse: # Test the core logic directly without Flask decorators
# Mock request parsing # Validate file ownership
mock_parser = Mock() result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id)
mock_parser.parse_args.return_value = {"as_attachment": False} assert result_message_file == mock_message_file
mock_reqparse.RequestParser.return_value = mock_parser assert result_upload_file == mock_upload_file
# Test the core logic directly without Flask decorators # Test file response building
# Validate file ownership response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) assert response is not None
assert result_message_file == mock_message_file
assert result_upload_file == mock_upload_file
# Test file response building # Verify storage was called correctly
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) mock_storage.load.assert_not_called() # Since we're testing components separately
assert response is not None
# Verify storage was called correctly
mock_storage.load.assert_not_called() # Since we're testing components separately
@patch("controllers.service_api.app.file_preview.storage") @patch("controllers.service_api.app.file_preview.storage")
def test_storage_error_handling( def test_storage_error_handling(
@@ -2,8 +2,6 @@ from pathlib import Path
from unittest.mock import Mock, create_autospec, patch from unittest.mock import Mock, create_autospec, patch
import pytest import pytest
from flask_restx import reqparse
from werkzeug.exceptions import BadRequest
from models.account import Account from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
@@ -77,60 +75,39 @@ class TestMetadataBugCompleteValidation:
         assert type_column.nullable is False, "type column should be nullable=False"
         assert name_column.nullable is False, "name column should be nullable=False"

-    def test_4_fixed_api_layer_rejects_null(self, app):
-        """Test Layer 4: Fixed API configuration properly rejects null values."""
-        # Test Console API create endpoint (fixed)
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("type", type=str, required=True, nullable=False, location="json")
-            .add_argument("name", type=str, required=True, nullable=False, location="json")
-        )
-
-        with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
-            with pytest.raises(BadRequest):
-                parser.parse_args()
-
-        # Test with just name being null
-        with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"):
-            with pytest.raises(BadRequest):
-                parser.parse_args()
-
-        # Test with just type being null
-        with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"):
-            with pytest.raises(BadRequest):
-                parser.parse_args()
-
-    def test_5_fixed_api_accepts_valid_values(self, app):
-        """Test that fixed API still accepts valid non-null values."""
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("type", type=str, required=True, nullable=False, location="json")
-            .add_argument("name", type=str, required=True, nullable=False, location="json")
-        )
-
-        with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"):
-            args = parser.parse_args()
-            assert args["type"] == "string"
-            assert args["name"] == "valid_name"
-
-    def test_6_simulated_buggy_behavior(self, app):
-        """Test simulating the original buggy behavior with nullable=True."""
-        # Simulate the old buggy configuration
-        buggy_parser = (
-            reqparse.RequestParser()
-            .add_argument("type", type=str, required=True, nullable=True, location="json")
-            .add_argument("name", type=str, required=True, nullable=True, location="json")
-        )
-
-        with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
-            # This would pass in the buggy version
-            args = buggy_parser.parse_args()
-            assert args["type"] is None
-            assert args["name"] is None
-
-            # But would crash when trying to create MetadataArgs
-            with pytest.raises((ValueError, TypeError)):
-                MetadataArgs.model_validate(args)
+    def test_4_fixed_api_layer_rejects_null(self):
+        """Test Layer 4: Fixed API configuration properly rejects null values using Pydantic."""
+        with pytest.raises((ValueError, TypeError)):
+            MetadataArgs.model_validate({"type": None, "name": None})
+
+        with pytest.raises((ValueError, TypeError)):
+            MetadataArgs.model_validate({"type": "string", "name": None})
+
+        with pytest.raises((ValueError, TypeError)):
+            MetadataArgs.model_validate({"type": None, "name": "test"})
+
+    def test_5_fixed_api_accepts_valid_values(self):
+        """Test that fixed API still accepts valid non-null values."""
+        args = MetadataArgs.model_validate({"type": "string", "name": "valid_name"})
+        assert args.type == "string"
+        assert args.name == "valid_name"
+
+    def test_6_simulated_buggy_behavior(self):
+        """Test simulating the original buggy behavior by bypassing Pydantic validation."""
+        mock_metadata_args = Mock()
+        mock_metadata_args.name = None
+        mock_metadata_args.type = None
+
+        mock_user = create_autospec(Account, instance=True)
+        mock_user.current_tenant_id = "tenant-123"
+        mock_user.id = "user-456"
+
+        with patch(
+            "services.metadata_service.current_account_with_tenant",
+            return_value=(mock_user, mock_user.current_tenant_id),
+        ):
+            with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
+                MetadataService.create_metadata("dataset-123", mock_metadata_args)

     def test_7_end_to_end_validation_layers(self):
         """Test all validation layers work together correctly."""

View File

@@ -1,7 +1,6 @@
 from unittest.mock import Mock, create_autospec, patch
 
 import pytest
-from flask_restx import reqparse
 
 from models.account import Account
 from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
@@ -51,76 +50,16 @@ class TestMetadataNullableBug:
         with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
             MetadataService.update_metadata_name("dataset-123", "metadata-456", None)

-    def test_api_parser_accepts_null_values(self, app):
-        """Test that API parser configuration incorrectly accepts null values."""
-        # Simulate the current API parser configuration
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("type", type=str, required=True, nullable=True, location="json")
-            .add_argument("name", type=str, required=True, nullable=True, location="json")
-        )
-
-        # Simulate request data with null values
-        with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
-            # This should parse successfully due to nullable=True
-            args = parser.parse_args()
-
-            # Verify that null values are accepted
-            assert args["type"] is None
-            assert args["name"] is None
-
-            # This demonstrates the bug: API accepts None but business logic will crash
-
-    def test_integration_bug_scenario(self, app):
-        """Test the complete bug scenario from API to service layer."""
-        # Step 1: API parser accepts null values (current buggy behavior)
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("type", type=str, required=True, nullable=True, location="json")
-            .add_argument("name", type=str, required=True, nullable=True, location="json")
-        )
-
-        with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
-            args = parser.parse_args()
-
-            # Step 2: Try to create MetadataArgs with None values
-            # This should fail at Pydantic validation level
-            with pytest.raises((ValueError, TypeError)):
-                metadata_args = MetadataArgs.model_validate(args)
-
-        # Step 3: If we bypass Pydantic (simulating the bug scenario)
-        # Move this outside the request context to avoid Flask-Login issues
-        mock_metadata_args = Mock()
-        mock_metadata_args.name = None  # From args["name"]
-        mock_metadata_args.type = None  # From args["type"]
-
-        mock_user = create_autospec(Account, instance=True)
-        mock_user.current_tenant_id = "tenant-123"
-        mock_user.id = "user-456"
-
-        with patch(
-            "services.metadata_service.current_account_with_tenant",
-            return_value=(mock_user, mock_user.current_tenant_id),
-        ):
-            # Step 4: Service layer crashes on len(None)
-            with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
-                MetadataService.create_metadata("dataset-123", mock_metadata_args)
-
-    def test_correct_nullable_false_configuration_works(self, app):
-        """Test that the correct nullable=False configuration works as expected."""
-        # This tests the FIXED configuration
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("type", type=str, required=True, nullable=False, location="json")
-            .add_argument("name", type=str, required=True, nullable=False, location="json")
-        )
-
-        with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
-            # This should fail with BadRequest due to nullable=False
-            from werkzeug.exceptions import BadRequest
-
-            with pytest.raises(BadRequest):
-                parser.parse_args()
+    def test_api_layer_now_uses_pydantic_validation(self):
+        """Verify that API layer relies on Pydantic validation instead of reqparse."""
+        invalid_payload = {"type": None, "name": None}
+        with pytest.raises((ValueError, TypeError)):
+            MetadataArgs.model_validate(invalid_payload)
+
+        valid_payload = {"type": "string", "name": "valid"}
+        args = MetadataArgs.model_validate(valid_payload)
+        assert args.type == "string"
+        assert args.name == "valid"


 if __name__ == "__main__":
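
The assertions change shape along with the validation: reqparse's parse_args() returned a namespace indexed like a dict (args["type"]), while model_validate() returns a typed instance with attribute access. A sketch using a hypothetical MetadataArgsLike stand-in for the real MetadataArgs imported above:

    from pydantic import BaseModel

    class MetadataArgsLike(BaseModel):
        type: str
        name: str

    args = MetadataArgsLike.model_validate({"type": "string", "name": "valid_name"})
    assert args.type == "string"       # attribute access on the model instance,
    assert args.name == "valid_name"   # not reqparse's args["type"] dict-style lookup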