mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:00:35 -04:00
refactor: replace bare dict with dict[str, Any] in core provider services and misc modules (#35124)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Configuration for InterSystems IRIS vector database."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PositiveInt, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate IRIS configuration values.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
|
||||
|
||||
id: str
|
||||
name: str
|
||||
credentials: dict
|
||||
credentials: dict[str, Any]
|
||||
credential_source_type: str | None = None
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
|
||||
class ExternalDataToolFactory:
|
||||
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
|
||||
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||
self.__extension_instance = extension_class(
|
||||
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict):
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ class ModelInstance:
|
||||
|
||||
@staticmethod
|
||||
def _get_load_balancing_manager(
|
||||
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
|
||||
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
|
||||
) -> Optional["LBModelManager"]:
|
||||
"""
|
||||
Get load balancing model credentials
|
||||
|
||||
@@ -96,11 +96,11 @@ class SimplePromptTransform(PromptTransform):
|
||||
app_mode: AppMode,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
pre_prompt: str,
|
||||
inputs: dict,
|
||||
inputs: dict[str, Any],
|
||||
query: str | None = None,
|
||||
context: str | None = None,
|
||||
histories: str | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
# get prompt template
|
||||
prompt_template_config = self.get_prompt_template(
|
||||
app_mode=app_mode,
|
||||
@@ -187,7 +187,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
self,
|
||||
app_mode: AppMode,
|
||||
pre_prompt: str,
|
||||
inputs: dict,
|
||||
inputs: dict[str, Any],
|
||||
query: str,
|
||||
context: str | None,
|
||||
files: Sequence["File"],
|
||||
@@ -234,7 +234,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
self,
|
||||
app_mode: AppMode,
|
||||
pre_prompt: str,
|
||||
inputs: dict,
|
||||
inputs: dict[str, Any],
|
||||
query: str,
|
||||
context: str | None,
|
||||
files: Sequence["File"],
|
||||
|
||||
@@ -856,7 +856,7 @@ class ProviderManager:
|
||||
secret_variables: list[str],
|
||||
cache_type: ProviderCredentialsCacheType,
|
||||
is_provider: bool = False,
|
||||
) -> dict:
|
||||
) -> dict[str, Any]:
|
||||
"""Get and decrypt credentials with caching."""
|
||||
credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -174,8 +174,8 @@ class RetrievalService:
|
||||
cls,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_model: dict | None = None,
|
||||
metadata_filtering_conditions: dict | None = None,
|
||||
external_retrieval_model: dict[str, Any] | None = None,
|
||||
metadata_filtering_conditions: dict[str, Any] | None = None,
|
||||
):
|
||||
stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
|
||||
@@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return embedding_results # type: ignore
|
||||
|
||||
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
"""Embed multimodal documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
file_id = multimodel_document["file_id"]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Embeddings(ABC):
|
||||
@@ -20,7 +21,7 @@ class Embeddings(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
"""Embed multimodal query."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
|
||||
return _case_routing
|
||||
|
||||
|
||||
def __getattr__(name: str) -> dict:
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Lazy module-level access to routing tables."""
|
||||
if name == "CASE_ROUTING":
|
||||
return _get_case_routing()
|
||||
|
||||
@@ -10,6 +10,7 @@ import tempfile
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import clickzetta
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
def validate_config(cls, values: dict[str, Any]):
|
||||
"""Validate the configuration values.
|
||||
|
||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
@@ -6,7 +8,7 @@ from models.model import AppMode, AppModelConfigDict
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
|
||||
def validate_configuration(cls, tenant_id: str, config: dict[str, Any], app_mode: AppMode) -> AppModelConfigDict:
|
||||
match app_mode:
|
||||
case AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -19,7 +20,7 @@ class ApiKeyAuthService:
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict):
|
||||
def create_provider_auth(tenant_id: str, args: dict[str, Any]):
|
||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
|
||||
@@ -428,7 +428,7 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def convert_builtin_provider_to_credential_entity(
|
||||
provider: BuiltinToolProvider, credentials: dict
|
||||
provider: BuiltinToolProvider, credentials: dict[str, Any]
|
||||
) -> ToolProviderCredentialApiEntity:
|
||||
return ToolProviderCredentialApiEntity(
|
||||
id=provider.id,
|
||||
|
||||
Reference in New Issue
Block a user