refactor: replace bare dict with TypedDicts in annotation_service (#34998)

This commit is contained in:
wdeveloper16
2026-04-13 05:46:33 +02:00
committed by GitHub
parent 17da0e4146
commit 8436470fcb
3 changed files with 110 additions and 29 deletions

View File

@@ -25,7 +25,13 @@ from fields.annotation_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
UpdateAnnotationArgs,
UpdateAnnotationSettingArgs,
UpsertAnnotationArgs,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource):
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
enable_args: EnableAnnotationArgs = {
"score_threshold": args.score_threshold,
"embedding_provider_name": args.embedding_provider_name,
"embedding_model_name": args.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource):
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
return result, 200
@@ -237,8 +249,16 @@ class AnnotationApi(Resource):
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
upsert_args: UpsertAnnotationArgs = {}
if args.answer is not None:
upsert_args["answer"] = args.answer
if args.content is not None:
upsert_args["content"] = args.content
if args.message_id is not None:
upsert_args["message_id"] = args.message_id
if args.question is not None:
upsert_args["question"] = args.question
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
update_args: UpdateAnnotationArgs = {}
if args.answer is not None:
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required

View File

@@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import Annotation, AnnotationList
from models.model import App
from services.annotation_service import AppAnnotationService
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
InsertAnnotationArgs,
UpdateAnnotationArgs,
)
class AnnotationCreatePayload(BaseModel):
@@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
enable_args: EnableAnnotationArgs = {
"score_threshold": payload.score_threshold,
"embedding_provider_name": payload.embedding_provider_name,
"embedding_model_name": payload.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200
@@ -135,8 +145,9 @@ class AnnotationListApi(Resource):
@validate_app_token
def post(self, app_model: App):
"""Create a new annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
@@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")

View File

@@ -1,11 +1,8 @@
import logging
import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from typing import TypedDict
import pandas as pd
from sqlalchemy import delete, or_, select, update
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
logger = logging.getLogger(__name__)
class AnnotationJobStatusDict(TypedDict):
job_id: str
@@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict):
enabled: bool
class EnableAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to enable_app_annotation."""
score_threshold: float
embedding_provider_name: str
embedding_model_name: str
class UpsertAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to up_insert_app_annotation_from_message."""
answer: str
content: str
message_id: str
question: str
class InsertAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to insert_app_annotation_directly."""
question: str
answer: str
class UpdateAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to update_app_annotation_directly.
Both fields are optional at the type level; the service validates at runtime
and raises ValueError if either is missing.
"""
answer: str
question: str
class UpdateAnnotationSettingArgs(TypedDict):
"""Expected shape of the args dict passed to update_app_annotation_setting."""
score_threshold: float
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@@ -62,8 +102,9 @@ class AppAnnotationService:
if answer is None:
raise ValueError("Either 'answer' or 'content' must be provided")
if args.get("message_id"):
message_id = str(args["message_id"])
raw_message_id = args.get("message_id")
if raw_message_id:
message_id = str(raw_message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
)
@@ -87,9 +128,10 @@ class AppAnnotationService:
account_id=current_user.id,
)
else:
question = args.get("question")
if not question:
maybe_question = args.get("question")
if not maybe_question:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
db.session.add(annotation)
@@ -110,7 +152,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@@ -217,7 +259,7 @@ class AppAnnotationService:
return annotations
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@@ -251,7 +293,7 @@ class AppAnnotationService:
return annotation
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@@ -270,7 +312,11 @@ class AppAnnotationService:
if question is None:
raise ValueError("'question' is required")
annotation.content = args["answer"]
answer = args.get("answer")
if answer is None:
raise ValueError("'answer' is required")
annotation.content = answer
annotation.question = question
db.session.commit()
@@ -613,7 +659,7 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info