refactor: replace bare dict with dict[str, Any] in provider entities and plugin client (#35077)

This commit is contained in:
wdeveloper16
2026-04-13 19:09:25 +02:00
committed by GitHub
parent 7056d2ae99
commit 4e0273bb28
4 changed files with 46 additions and 32 deletions

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
@@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: dict | None = None
data_source_info: dict[str, Any] | None = None
name: str
indexing_status: str
error: str | None = None

View File

@@ -6,6 +6,7 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
"""
Get current credentials.
@@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
return session.execute(stmt).scalar_one_or_none()
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
"""
Get provider credentials.
@@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None):
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
def update_provider_credential(
self,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
):
@@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
def _get_specific_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str
) -> dict | None:
) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> dict[str, Any] | None:
"""
Get custom model credentials.
@@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
self,
model_type: ModelType,
model: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
@@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
) -> None:
"""
Create a custom model credential.
@@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
self,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
credential_id: str,
) -> None:
"""
Update a custom model credential.
@@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
"""
Get model schema
"""
@@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from enum import StrEnum, auto
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, ConfigDict, Field
@@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
enabled: bool
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
credentials: dict | None = None
credentials: dict[str, Any] | None = None
class CustomProviderConfiguration(BaseModel):
@@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
Model class for provider custom configuration.
"""
credentials: dict
credentials: dict[str, Any]
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] = []
@@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict | None
credentials: dict[str, Any] | None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = []

View File

@@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> AIModelEntity | None:
"""
Get model schema
@@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> bool:
"""
validate the credentials of the provider
@@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
@@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
input_type: str,
) -> EmbeddingResult:
@@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
@@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
"""
@@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None = None,
@@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
@@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Generator[bytes, None, None]:
@@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
language: str | None = None,
):
"""
@@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
"""
@@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
text: str,
) -> bool:
"""