mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:00:26 -04:00
refactor: replace bare dict with dict[str, Any] in provider entities and plugin client (#35077)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
@@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
|
||||
id: str
|
||||
position: int
|
||||
data_source_type: str
|
||||
data_source_info: dict | None = None
|
||||
data_source_info: dict[str, Any] | None = None
|
||||
name: str
|
||||
indexing_status: str
|
||||
error: str | None = None
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
||||
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get current credentials.
|
||||
|
||||
@@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
|
||||
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a specific provider credential by ID.
|
||||
:param credential_id: Credential ID
|
||||
@@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = stmt.where(ProviderCredential.id != exclude_id)
|
||||
return session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
|
||||
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get provider credentials.
|
||||
|
||||
@@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
|
||||
def validate_provider_credentials(
|
||||
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
|
||||
):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
@@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
return provider_names
|
||||
|
||||
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
||||
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
|
||||
"""
|
||||
Add custom provider credentials.
|
||||
:param credentials: provider credentials
|
||||
@@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def update_provider_credential(
|
||||
self,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
):
|
||||
@@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def _get_specific_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credential_id: str
|
||||
) -> dict | None:
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a specific provider credential by ID.
|
||||
:param credential_id: Credential ID
|
||||
@@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
|
||||
return session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
|
||||
def get_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credential_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get custom model credentials.
|
||||
|
||||
@@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str = "",
|
||||
session: Session | None = None,
|
||||
):
|
||||
@@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return _validate(new_session)
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create a custom model credential.
|
||||
@@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
|
||||
raise
|
||||
|
||||
def update_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
|
||||
self,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_name: str | None,
|
||||
credential_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Update a custom model credential.
|
||||
@@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
|
||||
# Get model instance of LLM
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
|
||||
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
|
||||
def get_model_schema(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
@@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
|
||||
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
|
||||
"""
|
||||
Obfuscated credentials.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum, auto
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
|
||||
enabled: bool
|
||||
current_quota_type: ProviderQuotaType | None = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
credentials: dict | None = None
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CustomProviderConfiguration(BaseModel):
|
||||
@@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
|
||||
Model class for provider custom configuration.
|
||||
"""
|
||||
|
||||
credentials: dict
|
||||
credentials: dict[str, Any]
|
||||
current_credential_id: str | None = None
|
||||
current_credential_name: str | None = None
|
||||
available_credentials: list[CredentialConfiguration] = []
|
||||
@@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict | None
|
||||
credentials: dict[str, Any] | None
|
||||
current_credential_id: str | None = None
|
||||
current_credential_name: str | None = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
@@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
@@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
@@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
@@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
@@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
input_type: str,
|
||||
) -> EmbeddingResult:
|
||||
@@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
documents: list[dict],
|
||||
input_type: str,
|
||||
) -> EmbeddingResult:
|
||||
@@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
) -> list[int]:
|
||||
"""
|
||||
@@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: float | None = None,
|
||||
@@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
@@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
content_text: str,
|
||||
voice: str,
|
||||
) -> Generator[bytes, None, None]:
|
||||
@@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
language: str | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
file: IO[bytes],
|
||||
) -> str:
|
||||
"""
|
||||
@@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
text: str,
|
||||
) -> bool:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user