From 8436470fcbc03abee4f89fea698aacefe4c7d64c Mon Sep 17 00:00:00 2001 From: wdeveloper16 Date: Mon, 13 Apr 2026 05:46:33 +0200 Subject: [PATCH] refactor: replace bare dict with TypedDicts in annotation_service (#34998) --- api/controllers/console/app/annotation.py | 39 ++++++++-- api/controllers/service_api/app/annotation.py | 26 +++++-- api/services/annotation_service.py | 74 +++++++++++++++---- 3 files changed, 110 insertions(+), 29 deletions(-) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 9931bb5dd7..528785931e 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -25,7 +25,13 @@ from fields.annotation_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from services.annotation_service import AppAnnotationService +from services.annotation_service import ( + AppAnnotationService, + EnableAnnotationArgs, + UpdateAnnotationArgs, + UpdateAnnotationSettingArgs, + UpsertAnnotationArgs, +) DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource): args = AnnotationReplyPayload.model_validate(console_ns.payload) match action: case "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + enable_args: EnableAnnotationArgs = { + "score_threshold": args.score_threshold, + "embedding_provider_name": args.embedding_provider_name, + "embedding_model_name": args.embedding_model_name, + } + result = AppAnnotationService.enable_app_annotation(enable_args, app_id) case "disable": result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource): args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump()) + setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold} + result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args) return result, 200 @@ -237,8 +249,16 @@ class AnnotationApi(Resource): def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) - data = args.model_dump(exclude_none=True) - annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) + upsert_args: UpsertAnnotationArgs = {} + if args.answer is not None: + upsert_args["answer"] = args.answer + if args.content is not None: + upsert_args["content"] = args.content + if args.message_id is not None: + upsert_args["message_id"] = args.message_id + if args.question is not None: + upsert_args["question"] = args.question + annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) args = UpdateAnnotationPayload.model_validate(console_ns.payload) - annotation = AppAnnotationService.update_app_annotation_directly( - args.model_dump(exclude_none=True), app_id, annotation_id - ) + update_args: UpdateAnnotationArgs = {} + if args.answer is not None: + update_args["answer"] = args.answer + if args.question is not None: + update_args["question"] = args.question + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index c22190cbc9..00bb9aa463 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import Annotation, AnnotationList from models.model import App -from services.annotation_service import AppAnnotationService +from services.annotation_service import ( + AppAnnotationService, + EnableAnnotationArgs, + InsertAnnotationArgs, + UpdateAnnotationArgs, +) class AnnotationCreatePayload(BaseModel): @@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource): @validate_app_token def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" - args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() + payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}) match action: case "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) + enable_args: EnableAnnotationArgs = { + "score_threshold": payload.score_threshold, + "embedding_provider_name": payload.embedding_provider_name, + "embedding_model_name": payload.embedding_model_name, + } + result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id) case "disable": result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 @@ -135,8 +145,9 @@ class AnnotationListApi(Resource): @validate_app_token def post(self, app_model: App): """Create a new annotation.""" - args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() - annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) + payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) + insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer} + annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json"), HTTPStatus.CREATED @@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource): @edit_permission_required def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" - args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() - annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) + payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) + update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer} + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json") diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index ae5facbec0..ff0882ad5c 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,11 +1,8 @@ import logging import uuid - -import pandas as pd - -logger = logging.getLogger(__name__) from typing import TypedDict +import pandas as pd from sqlalchemy import delete, or_, select, update from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task +logger = logging.getLogger(__name__) + class AnnotationJobStatusDict(TypedDict): job_id: str @@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict): enabled: bool +class EnableAnnotationArgs(TypedDict): + """Expected shape of the args dict passed to enable_app_annotation.""" + + score_threshold: float + embedding_provider_name: str + embedding_model_name: str + + +class UpsertAnnotationArgs(TypedDict, total=False): + """Expected shape of the args dict passed to up_insert_app_annotation_from_message.""" + + answer: str + content: str + message_id: str + question: str + + +class InsertAnnotationArgs(TypedDict): + """Expected shape of the args dict passed to insert_app_annotation_directly.""" + + question: str + answer: str + + +class UpdateAnnotationArgs(TypedDict, total=False): + """Expected shape of the args dict passed to update_app_annotation_directly. + + Both fields are optional at the type level; the service validates at runtime + and raises ValueError if either is missing. + """ + + answer: str + question: str + + +class UpdateAnnotationSettingArgs(TypedDict): + """Expected shape of the args dict passed to update_app_annotation_setting.""" + + score_threshold: float + + class AppAnnotationService: @classmethod - def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: + def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -62,8 +102,9 @@ class AppAnnotationService: if answer is None: raise ValueError("Either 'answer' or 'content' must be provided") - if args.get("message_id"): - message_id = str(args["message_id"]) + raw_message_id = args.get("message_id") + if raw_message_id: + message_id = str(raw_message_id) message = db.session.scalar( select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1) ) @@ -87,9 +128,10 @@ class AppAnnotationService: account_id=current_user.id, ) else: - question = args.get("question") - if not question: + maybe_question = args.get("question") + if not maybe_question: raise ValueError("'question' is required when 'message_id' is not provided") + question = maybe_question annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id) db.session.add(annotation) @@ -110,7 +152,7 @@ class AppAnnotationService: return annotation @classmethod - def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict: + def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict: enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: @@ -217,7 +259,7 @@ class AppAnnotationService: return annotations @classmethod - def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: + def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -251,7 +293,7 @@ class AppAnnotationService: return annotation @classmethod - def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): + def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str): # get app info _, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -270,7 +312,11 @@ class AppAnnotationService: if question is None: raise ValueError("'question' is required") - annotation.content = args["answer"] + answer = args.get("answer") + if answer is None: + raise ValueError("'answer' is required") + + annotation.content = answer annotation.question = question db.session.commit() @@ -613,7 +659,7 @@ class AppAnnotationService: @classmethod def update_app_annotation_setting( - cls, app_id: str, annotation_setting_id: str, args: dict + cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs ) -> AnnotationSettingDict: current_user, current_tenant_id = current_account_with_tenant() # get app info