From 345ba80942902038bf8aaef17dddad5849b2706e Mon Sep 17 00:00:00 2001 From: Lillian <11332799+Lillian68@users.noreply.github.com> Date: Mon, 25 May 2026 15:33:32 +0800 Subject: [PATCH] fix: type mismatches (route says uuid: but handler says str) (#36612) --- api/controllers/console/app/message.py | 2 +- .../console/app/workflow_draft_variable.py | 29 ++-- api/controllers/console/app/workflow_run.py | 2 +- .../console/datasets/datasets_document.py | 4 +- .../rag_pipeline_draft_variable.py | 37 ++--- .../rag_pipeline/rag_pipeline_workflow.py | 2 +- api/controllers/service_api/app/annotation.py | 8 +- .../service_api/app/file_preview.py | 7 +- .../rag_pipeline/rag_pipeline_workflow.py | 22 +-- .../service_api/dataset/segment.py | 130 ++++++++++++------ 10 files changed, 149 insertions(+), 94 deletions(-) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index faa1e0fcda..7445fed86c 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -417,7 +417,7 @@ class MessageApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model, message_id: str): + def get(self, app_model, message_id: UUID): message_id_str = str(message_id) message = db.session.scalar( diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 3c887c33dc..83f0d1dde6 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -2,6 +2,7 @@ import logging from collections.abc import Callable from functools import wraps from typing import Any, TypedDict +from uuid import UUID from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with @@ -345,14 +346,15 @@ class VariableApi(Resource): @console_ns.response(404, "Variable not found") @_api_prerequisite @marshal_with(workflow_draft_variable_model) - def get(self, app_model: App, variable_id: str): + def get(self, app_model: App, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) + variable_id_str = str(variable_id) variable = _ensure_variable_access( - variable=draft_var_srv.get_variable(variable_id=variable_id), + variable=draft_var_srv.get_variable(variable_id=variable_id_str), app_id=app_model.id, - variable_id=variable_id, + variable_id=variable_id_str, ) return variable @@ -363,7 +365,7 @@ class VariableApi(Resource): @console_ns.response(404, "Variable not found") @_api_prerequisite @marshal_with(workflow_draft_variable_model) - def patch(self, app_model: App, variable_id: str): + def patch(self, app_model: App, variable_id: UUID): # Request payload for file types: # # Local File: @@ -390,10 +392,11 @@ class VariableApi(Resource): ) args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {}) + variable_id_str = str(variable_id) variable = _ensure_variable_access( - variable=draft_var_srv.get_variable(variable_id=variable_id), + variable=draft_var_srv.get_variable(variable_id=variable_id_str), app_id=app_model.id, - variable_id=variable_id, + variable_id=variable_id_str, ) new_name = args_model.name @@ -434,14 +437,15 @@ class VariableApi(Resource): @console_ns.response(204, "Variable deleted successfully") @console_ns.response(404, "Variable not found") @_api_prerequisite - def delete(self, app_model: App, variable_id: str): + def delete(self, app_model: App, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) + variable_id_str = str(variable_id) variable = _ensure_variable_access( - variable=draft_var_srv.get_variable(variable_id=variable_id), + variable=draft_var_srv.get_variable(variable_id=variable_id_str), app_id=app_model.id, - variable_id=variable_id, + variable_id=variable_id_str, ) draft_var_srv.delete_variable(variable) db.session.commit() @@ -457,7 +461,7 @@ class VariableResetApi(Resource): @console_ns.response(204, "Variable reset (no content)") @console_ns.response(404, "Variable not found") @_api_prerequisite - def put(self, app_model: App, variable_id: str): + def put(self, app_model: App, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) @@ -468,10 +472,11 @@ class VariableResetApi(Resource): raise NotFoundError( f"Draft workflow not found, app_id={app_model.id}", ) + variable_id_str = str(variable_id) variable = _ensure_variable_access( - variable=draft_var_srv.get_variable(variable_id=variable_id), + variable=draft_var_srv.get_variable(variable_id=variable_id_str), app_id=app_model.id, - variable_id=variable_id, + variable_id=variable_id_str, ) resetted = draft_var_srv.reset_variable(draft_workflow, variable) diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 2d48b59de2..9b46aeacb8 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -189,7 +189,7 @@ class WorkflowRunExportApi(Resource): @login_required @account_initialization_required @get_app_model() - def get(self, app_model: App, run_id: str): + def get(self, app_model: App, run_id: UUID): tenant_id = str(app_model.tenant_id) app_id = str(app_model.id) run_id_str = str(run_id) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index d387834e9b..bbfbe00a67 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -979,7 +979,7 @@ class DocumentDownloadApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def get(self, dataset_id: str, document_id: str) -> dict[str, Any]: + def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]: # Reuse the shared permission/tenant checks implemented in DocumentResource. document = self.get_document(str(dataset_id), str(document_id)) return {"url": DocumentService.get_document_download_url(document)} @@ -996,7 +996,7 @@ class DocumentBatchDownloadZipApi(DocumentResource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__]) - def post(self, dataset_id: str): + def post(self, dataset_id: UUID): """Stream a ZIP archive containing the requested uploaded documents.""" # Parse and validate request payload. payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {}) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index b31d73f27d..4baedf662e 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,6 +1,7 @@ import logging from collections.abc import Callable from typing import Any, NoReturn +from uuid import UUID from flask import Response, request from flask_restx import Resource, marshal, marshal_with @@ -168,21 +169,22 @@ class RagPipelineVariableApi(Resource): @_api_prerequisite @marshal_with(workflow_draft_variable_model) - def get(self, pipeline: Pipeline, variable_id: str): + def get(self, pipeline: Pipeline, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - variable = draft_var_srv.get_variable(variable_id=variable_id) + variable_id_str = str(variable_id) + variable = draft_var_srv.get_variable(variable_id=variable_id_str) if variable is None: - raise NotFoundError(description=f"variable not found, id={variable_id}") + raise NotFoundError(description=f"variable not found, id={variable_id_str}") if variable.app_id != pipeline.id: - raise NotFoundError(description=f"variable not found, id={variable_id}") + raise NotFoundError(description=f"variable not found, id={variable_id_str}") return variable @_api_prerequisite @marshal_with(workflow_draft_variable_model) @console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__]) - def patch(self, pipeline: Pipeline, variable_id: str): + def patch(self, pipeline: Pipeline, variable_id: UUID): # Request payload for file types: # # Local File: @@ -210,11 +212,12 @@ class RagPipelineVariableApi(Resource): payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) - variable = draft_var_srv.get_variable(variable_id=variable_id) + variable_id_str = str(variable_id) + variable = draft_var_srv.get_variable(variable_id=variable_id_str) if variable is None: - raise NotFoundError(description=f"variable not found, id={variable_id}") + raise NotFoundError(description=f"variable not found, id={variable_id_str}") if variable.app_id != pipeline.id: - raise NotFoundError(description=f"variable not found, id={variable_id}") + raise NotFoundError(description=f"variable not found, id={variable_id_str}") new_name = args.get(self._PATCH_NAME_FIELD, None) raw_value = args.get(self._PATCH_VALUE_FIELD, None) @@ -250,15 +253,16 @@ class RagPipelineVariableApi(Resource): return variable @_api_prerequisite - def delete(self, pipeline: Pipeline, variable_id: str): + def delete(self, pipeline: Pipeline, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - variable = draft_var_srv.get_variable(variable_id=variable_id) + variable_id_str = str(variable_id) + variable = draft_var_srv.get_variable(variable_id=variable_id_str) if variable is None: - raise NotFoundError(description=f"variable not found, id={variable_id}") + raise NotFoundError(description=f"variable not found, id={variable_id_str}") if variable.app_id != pipeline.id: - raise NotFoundError(description=f"variable not found, id={variable_id}") + raise NotFoundError(description=f"variable not found, id={variable_id_str}") draft_var_srv.delete_variable(variable) db.session.commit() return Response("", 204) @@ -267,7 +271,7 @@ class RagPipelineVariableApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/variables//reset") class RagPipelineVariableResetApi(Resource): @_api_prerequisite - def put(self, pipeline: Pipeline, variable_id: str): + def put(self, pipeline: Pipeline, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) @@ -278,11 +282,12 @@ class RagPipelineVariableResetApi(Resource): raise NotFoundError( f"Draft workflow not found, pipeline_id={pipeline.id}", ) - variable = draft_var_srv.get_variable(variable_id=variable_id) + variable_id_str = str(variable_id) + variable = draft_var_srv.get_variable(variable_id=variable_id_str) if variable is None: - raise NotFoundError(description=f"variable not found, id={variable_id}") + raise NotFoundError(description=f"variable not found, id={variable_id_str}") if variable.app_id != pipeline.id: - raise NotFoundError(description=f"variable not found, id={variable_id}") + raise NotFoundError(description=f"variable not found, id={variable_id_str}") resetted = draft_var_srv.reset_variable(draft_workflow, variable) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index a7727513df..5d6a779b5a 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -901,7 +901,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - def get(self, pipeline: Pipeline, run_id: str): + def get(self, pipeline: Pipeline, run_id: UUID): """ Get workflow run node execution list """ diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 9f6b1cf52e..12ab692adb 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -174,11 +174,11 @@ class AnnotationUpdateDeleteApi(Resource): ) @validate_app_token @edit_permission_required - def put(self, app_model: App, annotation_id: str): + def put(self, app_model: App, annotation_id: UUID): """Update an existing annotation.""" payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer} - annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id) + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id)) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json") @@ -195,7 +195,7 @@ class AnnotationUpdateDeleteApi(Resource): ) @validate_app_token @edit_permission_required - def delete(self, app_model: App, annotation_id: str): + def delete(self, app_model: App, annotation_id: UUID): """Delete an annotation.""" - AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) + AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id)) return "", 204 diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index 5e7847d784..44f765d866 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -1,5 +1,6 @@ import logging from urllib.parse import quote +from uuid import UUID from flask import Response, request from flask_restx import Resource @@ -50,20 +51,20 @@ class FilePreviewApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - def get(self, app_model: App, end_user: EndUser, file_id: str): + def get(self, app_model: App, end_user: EndUser, file_id: UUID): """ Preview/Download a file that was uploaded via Service API. Provides secure file preview/download functionality. Files can only be accessed if they belong to messages within the requesting app's context. """ - file_id = str(file_id) + file_id_str = str(file_id) # Parse query parameters args = FilePreviewQuery.model_validate(request.args.to_dict()) # Validate file ownership and get file objects - _, upload_file = self._validate_file_ownership(file_id, app_model.id) + _, upload_file = self._validate_file_ownership(file_id_str, app_model.id) # Get file content generator try: diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 8bc43bccd5..19b1d008b7 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -1,5 +1,6 @@ from collections.abc import Generator from typing import Any +from uuid import UUID from flask import request from pydantic import BaseModel @@ -64,10 +65,11 @@ class DatasourcePluginsApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - def get(self, tenant_id: str, dataset_id: str): + def get(self, tenant_id: str, dataset_id: UUID): """Resource for getting datasource plugins.""" + dataset_id_str = str(dataset_id) # Verify dataset ownership - stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) + stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str) dataset = db.session.scalar(stmt) if not dataset: raise NotFound("Dataset not found.") @@ -77,7 +79,7 @@ class DatasourcePluginsApi(DatasetApiResource): rag_pipeline_service: RagPipelineService = RagPipelineService() datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins( - tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published + tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published ) return datasource_plugins, 200 @@ -109,10 +111,11 @@ class DatasourceNodeRunApi(DatasetApiResource): } ) @service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__]) - def post(self, tenant_id: str, dataset_id: str, node_id: str): + def post(self, tenant_id: str, dataset_id: UUID, node_id: str): """Resource for getting datasource plugins.""" + dataset_id_str = str(dataset_id) # Verify dataset ownership - stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) + stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str) dataset = db.session.scalar(stmt) if not dataset: raise NotFound("Dataset not found.") @@ -120,7 +123,7 @@ class DatasourceNodeRunApi(DatasetApiResource): payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {}) assert isinstance(current_user, Account) rag_pipeline_service: RagPipelineService = RagPipelineService() - pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) + pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str) datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate( { **payload.model_dump(exclude_none=True), @@ -172,10 +175,11 @@ class PipelineRunApi(DatasetApiResource): } ) @service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__]) - def post(self, tenant_id: str, dataset_id: str): + def post(self, tenant_id: str, dataset_id: UUID): """Resource for running a rag pipeline.""" + dataset_id_str = str(dataset_id) # Verify dataset ownership - stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) + stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str) dataset = db.session.scalar(stmt) if not dataset: raise NotFound("Dataset not found.") @@ -186,7 +190,7 @@ class PipelineRunApi(DatasetApiResource): raise Forbidden() rag_pipeline_service: RagPipelineService = RagPipelineService() - pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) + pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str) try: response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate( pipeline=pipeline, diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 5992fa7410..34e1710068 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,4 +1,5 @@ from typing import Any +from uuid import UUID from flask import request from flask_restx import marshal @@ -107,17 +108,19 @@ class SegmentApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id: str, dataset_id: str, document_id: str): + def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): _, current_tenant_id = current_account_with_tenant() """Create single segment.""" + dataset_id_str = str(dataset_id) # check dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") + document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) if not document: raise NotFound("Document not found.") if document.indexing_status != "completed": @@ -150,7 +153,10 @@ class SegmentApi(DatasetApiResource): for args_item in payload.segments: SegmentService.segment_create_args_validate(args_item, document) segments = SegmentService.multi_create_segment(payload.segments, document, dataset) - return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200 + return { + "data": _marshal_segments_with_summary(segments, dataset_id_str), + "doc_form": document.doc_form, + }, 200 else: return {"error": "Segments is required"}, 400 @@ -165,19 +171,21 @@ class SegmentApi(DatasetApiResource): 404: "Dataset or document not found", } ) - def get(self, tenant_id: str, dataset_id: str, document_id: str): + def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID): _, current_tenant_id = current_account_with_tenant() """Get segments.""" # check dataset page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) + dataset_id_str = str(dataset_id) dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") + document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) if not document: raise NotFound("Document not found.") # check embedding model setting @@ -205,7 +213,7 @@ class SegmentApi(DatasetApiResource): ) segments, total = SegmentService.get_segments( - document_id=document_id, + document_id=document_id_str, tenant_id=current_tenant_id, status_list=args.status, keyword=args.keyword, @@ -214,7 +222,7 @@ class SegmentApi(DatasetApiResource): ) response = { - "data": _marshal_segments_with_summary(segments, dataset_id), + "data": _marshal_segments_with_summary(segments, dataset_id_str), "doc_form": document.doc_form, "total": total, "has_more": len(segments) == limit, @@ -240,22 +248,25 @@ class DatasetSegmentApi(DatasetApiResource): } ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID): _, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) # check dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) + document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset_id, document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") + segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) @@ -276,18 +287,20 @@ class DatasetSegmentApi(DatasetApiResource): ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID): _, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) # check dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) + document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset_id, document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: @@ -306,15 +319,19 @@ class DatasetSegmentApi(DatasetApiResource): ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) + segment_id_str = str(segment_id) + # check segment + segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {}) updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset) - return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200 + return { + "data": _marshal_segment_with_summary(updated_segment, dataset_id_str), + "doc_form": document.doc_form, + }, 200 @service_api_ns.doc("get_segment") @service_api_ns.doc(description="Get a specific segment by ID") @@ -325,26 +342,29 @@ class DatasetSegmentApi(DatasetApiResource): 404: "Dataset, document, or segment not found", } ) - def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID): _, current_tenant_id = current_account_with_tenant() + dataset_id_str = str(dataset_id) # check dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) + document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset_id, document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") + segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") - return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 + return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200 @service_api_ns.route( @@ -369,23 +389,26 @@ class ChildChunkApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID): _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" + dataset_id_str = str(dataset_id) # check dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") + document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) if not document: raise NotFound("Document not found.") + segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -429,23 +452,26 @@ class ChildChunkApi(DatasetApiResource): 404: "Dataset, document, or segment not found", } ) - def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID): _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" + dataset_id_str = str(dataset_id) # check dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") + document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) if not document: raise NotFound("Document not found.") + segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -461,7 +487,9 @@ class ChildChunkApi(DatasetApiResource): limit = min(args.limit, 100) keyword = args.keyword - child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) + child_chunks = SegmentService.get_child_chunks( + segment_id_str, document_id_str, dataset_id_str, page, limit, keyword + ) return { "data": marshal(child_chunks.items, child_chunk_fields), @@ -497,32 +525,38 @@ class DatasetChildChunkApi(DatasetApiResource): ) @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): + def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID): _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" + dataset_id_str = str(dataset_id) # check dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") + document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id_str) if not document: raise NotFound("Document not found.") + segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") # validate segment belongs to the specified document - if str(segment.document_id) != str(document_id): + if str(segment.document_id) != str(document_id_str): raise NotFound("Document not found.") + child_chunk_id_str = str(child_chunk_id) # check child chunk - child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id) + child_chunk = SegmentService.get_child_chunk_by_id( + child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id + ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -558,32 +592,38 @@ class DatasetChildChunkApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): + def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID): _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" + dataset_id_str = str(dataset_id) # check dataset dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1) ) if not dataset: raise NotFound("Dataset not found.") + document_id_str = str(document_id) # get document - document = DocumentService.get_document(dataset_id, document_id) + document = DocumentService.get_document(dataset_id_str, document_id_str) if not document: raise NotFound("Document not found.") + segment_id_str = str(segment_id) # get segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") # validate segment belongs to the specified document - if str(segment.document_id) != str(document_id): + if str(segment.document_id) != str(document_id_str): raise NotFound("Segment not found.") + child_chunk_id_str = str(child_chunk_id) # get child chunk - child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id) + child_chunk = SegmentService.get_child_chunk_by_id( + child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id + ) if not child_chunk: raise NotFound("Child chunk not found.")