From 4e0273bb28f38dcd86e244b62becb737548099a0 Mon Sep 17 00:00:00 2001 From: wdeveloper16 Date: Mon, 13 Apr 2026 19:09:25 +0200 Subject: [PATCH] refactor: replace bare dict with dict[str, Any] in provider entities and plugin client (#35077) --- api/core/entities/knowledge_entities.py | 4 ++- api/core/entities/provider_configuration.py | 38 ++++++++++++++------- api/core/entities/provider_entities.py | 8 ++--- api/core/plugin/impl/model.py | 28 +++++++-------- 4 files changed, 46 insertions(+), 32 deletions(-) diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index b1ba3c3e2a..a13938f3fb 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field, field_validator @@ -37,7 +39,7 @@ class PipelineDocument(BaseModel): id: str position: int data_source_type: str - data_source_info: dict | None = None + data_source_info: dict[str, Any] | None = None name: str indexing_status: str error: str | None = None diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index f3b2c31465..d07f6f913a 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -6,6 +6,7 @@ import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError +from typing import Any from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel): return ModelProviderFactory(model_runtime=self._bound_model_runtime) return create_plugin_model_provider_factory(tenant_id=self.tenant_id) - def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: + def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None: """ Get current credentials. @@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel): return session.execute(stmt).scalar_one_or_none() - def _get_specific_provider_credential(self, credential_id: str) -> dict | None: + def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None: """ Get a specific provider credential by ID. :param credential_id: Credential ID @@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_provider_credential(self, credential_id: str | None = None) -> dict | None: + def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None: """ Get provider credentials. @@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel): else [], ) - def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): + def validate_provider_credentials( + self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None + ): """ Validate custom credentials. :param credentials: provider credentials @@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel): provider_names.append(model_provider_id.provider_name) return provider_names - def create_provider_credential(self, credentials: dict, credential_name: str | None): + def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None): """ Add custom provider credentials. :param credentials: provider credentials @@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel): def update_provider_credential( self, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ): @@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel): def _get_specific_custom_model_credential( self, model_type: ModelType, model: str, credential_id: str - ) -> dict | None: + ) -> dict[str, Any] | None: """ Get a specific provider credential by ID. :param credential_id: Credential ID @@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderModelCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None: + def get_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str | None + ) -> dict[str, Any] | None: """ Get custom model credentials. @@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel): self, model_type: ModelType, model: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str = "", session: Session | None = None, ): @@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel): return _validate(new_session) def create_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None + self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None ) -> None: """ Create a custom model credential. @@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel): raise def update_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str + self, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + credential_name: str | None, + credential_id: str, ) -> None: """ Update a custom model credential. @@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel): # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) - def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None: + def get_model_schema( + self, model_type: ModelType, model: str, credentials: dict[str, Any] | None + ) -> AIModelEntity | None: """ Get model schema """ @@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel): return secret_input_form_variables - def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): + def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]): """ Obfuscated credentials. diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 2c8767a32b..5da88c0beb 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import StrEnum, auto -from typing import Union +from typing import Any, Union from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field @@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel): enabled: bool current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] - credentials: dict | None = None + credentials: dict[str, Any] | None = None class CustomProviderConfiguration(BaseModel): @@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel): Model class for provider custom configuration. """ - credentials: dict + credentials: dict[str, Any] current_credential_id: str | None = None current_credential_name: str | None = None available_credentials: list[CredentialConfiguration] = [] @@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict | None + credentials: dict[str, Any] | None current_credential_id: str | None = None current_credential_name: str | None = None available_model_credentials: list[CredentialConfiguration] = [] diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 1e38c24717..e54bebd7ac 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], ) -> AIModelEntity | None: """ Get model schema @@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], ) -> bool: """ validate the credentials of the provider @@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, @@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], prompt_messages: list[PromptMessage], tools: list[PromptMessageTool] | None = None, ) -> int: @@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], texts: list[str], input_type: str, ) -> EmbeddingResult: @@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], documents: list[dict], input_type: str, ) -> EmbeddingResult: @@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], texts: list[str], ) -> list[int]: """ @@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], query: str, docs: list[str], score_threshold: float | None = None, @@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], query: MultimodalRerankInput, docs: list[MultimodalRerankInput], score_threshold: float | None = None, @@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], content_text: str, voice: str, ) -> Generator[bytes, None, None]: @@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], language: str | None = None, ): """ @@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], file: IO[bytes], ) -> str: """ @@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], text: str, ) -> bool: """