diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 3f5331a443..d4bfe874ff 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,7 +1,7 @@ import logging from collections.abc import Mapping from datetime import datetime -from typing import Literal +from typing import Literal, override from dateutil.parser import isoparse from flask import request @@ -76,11 +76,13 @@ def _enum_value(value): class WorkflowRunStatusField(fields.Raw): + @override def output(self, key, obj: WorkflowRun, **kwargs): return _enum_value(obj.status) class WorkflowRunOutputsField(fields.Raw): + @override def output(self, key, obj: WorkflowRun, **kwargs): status = _enum_value(obj.status) if status == WorkflowExecutionStatus.PAUSED.value: diff --git a/api/core/app/apps/draft_variable_saver.py b/api/core/app/apps/draft_variable_saver.py index 24018012c5..0048989e79 100644 --- a/api/core/app/apps/draft_variable_saver.py +++ b/api/core/app/apps/draft_variable_saver.py @@ -2,7 +2,7 @@ from __future__ import annotations import abc from collections.abc import Mapping -from typing import Any, Protocol +from typing import Any, Protocol, override from graphon.enums import NodeType @@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol): class NoopDraftVariableSaver(DraftVariableSaver): + @override def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: return None diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 453dd41957..91c46d07a8 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -6,7 +6,7 @@ import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from typing import Any +from typing import Any, override from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select @@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel): key = str(ModelProviderID(key)) return key in self.configurations + @override def __iter__(self): # Return an iterator of (key, value) tuples to match BaseModel's __iter__ yield from self.configurations.items() diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 8ce068cfbb..30fcbc230f 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, TypedDict +from typing import Any, TypedDict, override from sqlalchemy import select @@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool): """the unique name of external data tool""" @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool): if not api_based_extension: raise ValueError("api_based_extension_id is invalid") + @override def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index 173913196e..64596969ef 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens. import logging from collections.abc import Callable -from typing import Any +from typing import Any, override from sqlalchemy.orm import Session @@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient): # Reset retry flag after operation completes self._has_retried = False + @override def __enter__(self): """Enter the context manager with retry support.""" @@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient): return self._execute_with_retry(initialize_with_retry) + @override def list_tools(self) -> list[Tool]: """ List available tools from the MCP server with auth retry. @@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient): """ return self._execute_with_retry(super().list_tools) + @override def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: """ Invoke a tool on the MCP server with auth retry. diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index d684fe0dd7..f91295a432 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -1,6 +1,6 @@ import queue from datetime import timedelta -from typing import Any, Protocol +from typing import Any, Protocol, override from pydantic import AnyUrl, TypeAdapter @@ -159,6 +159,7 @@ class ClientSession( types.EmptyResult, ) + @override def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): """Send a progress notification.""" self.send_notification( @@ -326,6 +327,7 @@ class ClientSession( ) ) + @override def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]): ctx = RequestContext[ClientSession, Any]( request_id=responder.request_id, @@ -351,6 +353,7 @@ class ClientSession( with responder: return responder.respond(types.ClientResult(root=types.EmptyResult())) + @override def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -358,6 +361,7 @@ class ClientSession( """Handle incoming messages by forwarding to the message handler.""" self._message_handler(req) + @override def _received_notification(self, notification: types.ServerNotification): """Handle notifications from the server.""" # Process specific notification types diff --git a/api/core/plugin/impl/exc.py b/api/core/plugin/impl/exc.py index 4cabdc1732..9a4f51ef12 100644 --- a/api/core/plugin/impl/exc.py +++ b/api/core/plugin/impl/exc.py @@ -1,4 +1,5 @@ from collections.abc import Mapping +from typing import override from pydantic import TypeAdapter @@ -11,6 +12,7 @@ class PluginDaemonError(Exception): def __init__(self, description: str): self.description = description + @override def __str__(self) -> str: # returns the class name and description return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}" diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index b3f174bf78..be83f69c48 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, TypedDict +from typing import Any, TypedDict, override import orjson from pydantic import BaseModel @@ -29,6 +29,7 @@ class Jieba(BaseKeyword): super().__init__(dataset) self._config = KeywordTableConfig() + @override def create(self, texts: list[Document], **kwargs) -> BaseKeyword: lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): @@ -48,6 +49,7 @@ class Jieba(BaseKeyword): return self + @override def add_texts(self, texts: list[Document], **kwargs): lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): @@ -72,12 +74,14 @@ class Jieba(BaseKeyword): self._save_dataset_keyword_table(keyword_table) + @override def text_exists(self, id: str) -> bool: keyword_table = self._get_dataset_keyword_table() if keyword_table is None: return False return id in set.union(*keyword_table.values()) + @override def delete_by_ids(self, ids: list[str]): lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): @@ -87,6 +91,7 @@ class Jieba(BaseKeyword): self._save_dataset_keyword_table(keyword_table) + @override def search(self, query: str, **kwargs: Any) -> list[Document]: keyword_table = self._get_dataset_keyword_table() @@ -122,6 +127,7 @@ class Jieba(BaseKeyword): return documents + @override def delete(self): lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 1f82f7a081..cd73bb9b1a 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -2,7 +2,7 @@ import base64 import logging import time from abc import ABC, abstractmethod -from typing import Any +from typing import Any, override from sqlalchemy import select @@ -72,21 +72,27 @@ class _LazyEmbeddings(Embeddings): self._real = CacheEmbedding(embedding_model) return self._real + @override def embed_documents(self, texts: list[str]) -> list[list[float]]: return self._ensure().embed_documents(texts) + @override def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: return self._ensure().embed_multimodal_documents(multimodel_documents) + @override def embed_query(self, text: str) -> list[float]: return self._ensure().embed_query(text) + @override def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: return self._ensure().embed_multimodal_query(multimodel_document) + @override async def aembed_documents(self, texts: list[str]) -> list[list[float]]: return await self._ensure().aembed_documents(texts) + @override async def aembed_query(self, text: str) -> list[float]: return await self._ensure().aembed_query(text)