feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -93,10 +93,14 @@ class ModelCredentialSchema(BaseModel):
class SimpleProviderEntity(BaseModel):
"""
Simple model class for provider.
Simplified provider schema exposed to callers.
`provider` is the canonical runtime identifier. `provider_name` is an optional
compatibility alias for short-name lookups and is empty when no alias exists.
"""
provider: str
provider_name: str = ""
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
@@ -115,10 +119,15 @@ class ProviderHelpEntity(BaseModel):
class ProviderEntity(BaseModel):
"""
Model class for provider.
Runtime-native provider schema.
`provider` is the canonical runtime identifier. `provider_name` is a
compatibility alias for callers that still resolve providers by short name and
is empty when no alias exists.
"""
provider: str
provider_name: str = ""
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
@@ -153,6 +162,7 @@ class ProviderEntity(BaseModel):
"""
return SimpleProviderEntity(
provider=self.provider,
provider_name=self.provider_name,
label=self.label,
icon_small=self.icon_small,
supported_model_types=self.supported_model_types,

View File

@@ -1,6 +1,13 @@
from typing import TypedDict
from pydantic import BaseModel
class MultimodalRerankInput(TypedDict):
content: str
content_type: str
class RerankDocument(BaseModel):
"""
Model class for rerank document.

View File

@@ -1,10 +1,18 @@
from decimal import Decimal
from enum import StrEnum, auto
from pydantic import BaseModel
from dify_graph.model_runtime.entities.model_entities import ModelUsage
class EmbeddingInputType(StrEnum):
"""Embedding request input variants understood by the model runtime."""
DOCUMENT = auto()
QUERY = auto()
class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.

View File

@@ -1,12 +1,5 @@
import decimal
import hashlib
import logging
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from redis import RedisError
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from dify_graph.model_runtime.entities.model_entities import (
@@ -17,6 +10,7 @@ from dify_graph.model_runtime.entities.model_entities import (
PriceInfo,
PriceType,
)
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@@ -25,45 +19,61 @@ from dify_graph.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
from dify_graph.model_runtime.runtime import ModelRuntime
class AIModel(BaseModel):
class AIModel:
"""
Base class for all models.
Runtime-facing base class for all model providers.
This stays a regular Python class because instances hold live collaborators
such as the provider schema and runtime adapter rather than user input that
benefits from Pydantic validation. Subclasses must pin ``model_type`` via a
class attribute; the base class is not meant to be instantiated directly.
"""
tenant_id: str = Field(description="Tenant ID")
model_type: ModelType = Field(description="Model type")
plugin_id: str = Field(description="Plugin ID")
provider_name: str = Field(description="Provider")
plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider")
started_at: float = Field(description="Invoke start time", default=0)
model_type: ModelType
provider_schema: ProviderEntity
model_runtime: ModelRuntime
started_at: float
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(
self,
provider_schema: ProviderEntity,
model_runtime: ModelRuntime,
*,
started_at: float = 0,
) -> None:
if getattr(type(self), "model_type", None) is None:
raise TypeError("AIModel subclasses must define model_type as a class attribute")
self.model_type = type(self).model_type
self.provider_schema = provider_schema
self.model_runtime = model_runtime
self.started_at = started_at
@property
def provider(self) -> str:
return self.provider_schema.provider
@property
def provider_display_name(self) -> str:
return self.provider_schema.label.en_US
@property
def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
Map model invoke error to unified error.
:return: Invoke error mapping
The key is the error type thrown to the caller, and the value contains
runtime-facing exception types that should be normalized to it.
"""
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError],
PluginDaemonInnerError: [PluginDaemonInnerError],
ValueError: [ValueError],
}
@@ -79,15 +89,18 @@ class AIModel(BaseModel):
if invoke_error == InvokeAuthorizationError:
return InvokeAuthorizationError(
description=(
f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
f"[{self.provider_display_name}] Incorrect model credentials provided, "
"please check and try again."
)
)
elif isinstance(invoke_error, InvokeError):
return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
return InvokeError(
description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}"
)
else:
return error
return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}")
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
"""
@@ -144,65 +157,13 @@ class AIModel(BaseModel):
:param credentials: model credentials
:return: model schema
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
return self.model_runtime.get_model_schema(
provider=self.provider,
model_type=self.model_type,
model=model,
credentials=credentials or {},
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema from credentials

View File

@@ -4,9 +4,6 @@ import uuid
from collections.abc import Callable, Generator, Iterator, Sequence
from typing import Union
from pydantic import ConfigDict
from configs import dify_config
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
@@ -140,11 +137,9 @@ def _build_llm_result_from_chunks(
)
def _invoke_llm_via_plugin(
def _invoke_llm_via_runtime(
*,
tenant_id: str,
user_id: str,
plugin_id: str,
llm_model: "LargeLanguageModel",
provider: str,
model: str,
credentials: dict,
@@ -154,25 +149,19 @@ def _invoke_llm_via_plugin(
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_llm(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
return llm_model.model_runtime.invoke_llm(
provider=provider,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stop=stop,
stream=stream,
)
def _normalize_non_stream_plugin_result(
def _normalize_non_stream_runtime_result(
model: str,
prompt_messages: Sequence[PromptMessage],
result: Union[LLMResult, Iterator[LLMResultChunk]],
@@ -208,9 +197,6 @@ class LargeLanguageModel(AIModel):
model_type: ModelType = ModelType.LLM
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
@@ -220,7 +206,6 @@ class LargeLanguageModel(AIModel):
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
"""
@@ -233,7 +218,6 @@ class LargeLanguageModel(AIModel):
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -245,7 +229,7 @@ class LargeLanguageModel(AIModel):
callbacks = callbacks or []
if dify_config.DEBUG:
if logger.isEnabledFor(logging.DEBUG):
callbacks.append(LoggingCallback())
# trigger before invoke callbacks
@@ -257,18 +241,15 @@ class LargeLanguageModel(AIModel):
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
result = _invoke_llm_via_plugin(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
result = _invoke_llm_via_runtime(
llm_model=self,
provider=self.provider,
model=model,
credentials=credentials,
model_parameters=model_parameters,
@@ -279,7 +260,7 @@ class LargeLanguageModel(AIModel):
)
if not stream:
result = _normalize_non_stream_plugin_result(
result = _normalize_non_stream_runtime_result(
model=model, prompt_messages=prompt_messages, result=result
)
except Exception as e:
@@ -292,7 +273,6 @@ class LargeLanguageModel(AIModel):
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
@@ -309,7 +289,6 @@ class LargeLanguageModel(AIModel):
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
elif isinstance(result, LLMResult):
@@ -322,7 +301,6 @@ class LargeLanguageModel(AIModel):
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
# Following https://github.com/langgenius/dify/issues/17799,
@@ -435,22 +413,14 @@ class LargeLanguageModel(AIModel):
:param tools: tools for tool calling
:return:
"""
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
tools=tools,
)
return 0
return self.model_runtime.get_llm_num_tokens(
provider=self.provider,
model_type=self.model_type,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
tools=tools,
)
def calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int

View File

@@ -1,7 +1,5 @@
import time
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -13,30 +11,20 @@ class ModerationModel(AIModel):
model_type: ModelType = ModelType.MODERATION
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
def invoke(self, model: str, credentials: dict, text: str) -> bool:
"""
Invoke moderation model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
self.started_at = time.perf_counter()
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_moderation(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_moderation(
provider=self.provider,
model=model,
credentials=credentials,
text=text,

View File

@@ -1,5 +1,5 @@
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -18,7 +18,6 @@ class RerankModel(AIModel):
docs: list[str],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -29,18 +28,11 @@ class RerankModel(AIModel):
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_rerank(
provider=self.provider,
model=model,
credentials=credentials,
query=query,
@@ -55,11 +47,10 @@ class RerankModel(AIModel):
self,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
@@ -69,18 +60,11 @@ class RerankModel(AIModel):
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_multimodal_rerank(
provider=self.provider,
model=model,
credentials=credentials,
query=query,

View File

@@ -1,7 +1,5 @@
from typing import IO
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -13,28 +11,18 @@ class Speech2TextModel(AIModel):
model_type: ModelType = ModelType.SPEECH2TEXT
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
"""
Invoke speech to text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_speech_to_text(
provider=self.provider,
model=model,
credentials=credentials,
file=file,

View File

@@ -1,8 +1,5 @@
from pydantic import ConfigDict
from core.entities.embedding_type import EmbeddingInputType
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -13,16 +10,12 @@ class TextEmbeddingModel(AIModel):
model_type: ModelType = ModelType.TEXT_EMBEDDING
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
credentials: dict,
texts: list[str] | None = None,
multimodel_documents: list[dict] | None = None,
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
@@ -32,31 +25,21 @@ class TextEmbeddingModel(AIModel):
:param credentials: model credentials
:param texts: texts to embed
:param files: files to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
from core.plugin.impl.model import PluginModelClient
try:
plugin_model_manager = PluginModelClient()
if texts:
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_text_embedding(
provider=self.provider,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if multimodel_documents:
return plugin_model_manager.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_multimodal_embedding(
provider=self.provider,
model=model,
credentials=credentials,
documents=multimodel_documents,
@@ -75,14 +58,8 @@ class TextEmbeddingModel(AIModel):
:param texts: texts to embed
:return:
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.get_text_embedding_num_tokens(
provider=self.provider,
model=model,
credentials=credentials,
texts=texts,

View File

@@ -1,8 +1,6 @@
import logging
from collections.abc import Iterable
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -16,38 +14,25 @@ class TTSModel(AIModel):
model_type: ModelType = ModelType.TTS
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: str | None = None,
) -> Iterable[bytes]:
"""
Invoke large language model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: translated audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_tts(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_tts(
provider=self.provider,
model=model,
credentials=credentials,
content_text=content_text,
@@ -65,14 +50,8 @@ class TTSModel(AIModel):
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.get_tts_model_voices(
provider=self.provider,
model=model,
credentials=credentials,
language=language,

View File

@@ -1,16 +1,7 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Sequence
from threading import Lock
from pydantic import ValidationError
from redis import RedisError
import contexts
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -20,120 +11,64 @@ from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankM
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
from dify_graph.model_runtime.runtime import ModelRuntime
from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import (
ProviderCredentialSchemaValidator,
)
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class ModelProviderFactory:
def __init__(self, tenant_id: str):
from core.plugin.impl.model import PluginModelClient
"""Factory for provider schemas and model-type instances backed by a runtime adapter."""
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelClient()
def __init__(self, model_runtime: ModelRuntime):
if model_runtime is None:
raise ValueError("model_runtime is required.")
self.model_runtime = model_runtime
def get_providers(self) -> Sequence[ProviderEntity]:
"""
Get all providers
:return: list of providers
Get all providers.
"""
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
# The plugin server should return providers in the desired order
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
return list(self.get_model_providers())
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
def get_model_providers(self) -> Sequence[ProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
Get all model providers exposed by the runtime adapter.
"""
# check if context is set
try:
contexts.plugin_model_providers.get()
except LookupError:
contexts.plugin_model_providers.set(None)
contexts.plugin_model_providers_lock.set(Lock())
with contexts.plugin_model_providers_lock.get():
plugin_model_providers = contexts.plugin_model_providers.get()
if plugin_model_providers is not None:
return plugin_model_providers
plugin_model_providers = []
contexts.plugin_model_providers.set(plugin_model_providers)
# Fetch plugin model providers
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
for provider in plugin_providers:
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
plugin_model_providers.append(provider)
return plugin_model_providers
return self.model_runtime.fetch_model_providers()
def get_provider_schema(self, provider: str) -> ProviderEntity:
"""
Get provider schema
:param provider: provider name
:return: provider schema
Get provider schema.
"""
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
return self.get_model_provider(provider=provider)
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
def get_model_provider(self, provider: str) -> ProviderEntity:
"""
Get plugin model provider
:param provider: provider name
:return: provider schema
Get provider schema.
"""
if "/" not in provider:
provider = str(ModelProviderID(provider))
# fetch plugin model providers
plugin_model_provider_entities = self.get_plugin_model_providers()
# get the provider
plugin_model_provider_entity = next(
(p for p in plugin_model_provider_entities if p.declaration.provider == provider),
None,
)
if not plugin_model_provider_entity:
provider_entity = self._resolve_provider(provider)
if provider_entity is None:
raise ValueError(f"Invalid provider: {provider}")
return plugin_model_provider_entity
return provider_entity
def provider_credentials_validate(self, *, provider: str, credentials: dict):
"""
Validate provider credentials
:param provider: provider name
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
:return:
Validate provider credentials.
"""
# fetch plugin model provider
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
provider_entity = self.get_model_provider(provider=provider)
# get provider_credential_schema and validate credentials according to the rules
provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema
provider_credential_schema = provider_entity.provider_credential_schema
if not provider_credential_schema:
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
# validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
# validate the credentials, raise exception if validation failed
self.plugin_model_manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=plugin_model_provider_entity.provider,
self.model_runtime.validate_provider_credentials(
provider=provider_entity.provider,
credentials=filtered_credentials,
)
@@ -141,33 +76,20 @@ class ModelProviderFactory:
def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict):
"""
Validate model credentials
:param provider: provider name
:param model_type: model type
:param model: model name
:param credentials: model credentials, credentials form defined in `model_credential_schema`.
:return:
Validate model credentials.
"""
# fetch plugin model provider
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
provider_entity = self.get_model_provider(provider=provider)
# get model_credential_schema and validate credentials according to the rules
model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema
model_credential_schema = provider_entity.model_credential_schema
if not model_credential_schema:
raise ValueError(f"Provider {provider} does not have model_credential_schema")
# validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
# call validate_credentials method of model type to validate credentials, raise exception if validation failed
self.plugin_model_manager.validate_model_credentials(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=plugin_model_provider_entity.provider,
model_type=model_type.value,
self.model_runtime.validate_model_credentials(
provider=provider_entity.provider,
model_type=model_type,
model=model,
credentials=filtered_credentials,
)
@@ -178,65 +100,16 @@ class ModelProviderFactory:
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
) -> AIModelEntity | None:
"""
Get model schema
Get model schema.
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
provider_entity = self.get_model_provider(provider)
return self.model_runtime.get_model_schema(
provider=provider_entity.provider,
model_type=model_type,
model=model,
credentials=credentials or {},
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_models(
self,
*,
@@ -245,143 +118,56 @@ class ModelProviderFactory:
provider_configs: list[ProviderConfig] | None = None,
) -> list[SimpleProviderEntity]:
"""
Get all models for given model type
:param provider: provider name
:param model_type: model type
:param provider_configs: list of provider configs
:return: list of models
Get all models for given model type.
"""
provider_configs = provider_configs or []
# scan all providers
plugin_model_provider_entities = self.get_plugin_model_providers()
# traverse all model_provider_extensions
providers = []
for plugin_model_provider_entity in plugin_model_provider_entities:
# filter by provider if provider is present
if provider and plugin_model_provider_entity.declaration.provider != provider:
for provider_entity in self.get_model_providers():
if provider and not self._matches_provider(provider_entity, provider):
continue
# get provider schema
provider_schema = plugin_model_provider_entity.declaration
model_types = provider_schema.supported_model_types
if model_type:
if model_type not in model_types:
continue
model_types = [model_type]
all_model_type_models = []
for model_schema in provider_schema.models:
if model_schema.model_type != model_type:
continue
all_model_type_models.append(model_schema)
simple_provider_schema = provider_schema.to_simple_provider()
if model_type:
simple_provider_schema.models = all_model_type_models
if model_type and model_type not in provider_entity.supported_model_types:
continue
simple_provider_schema = provider_entity.to_simple_provider()
if model_type is not None:
simple_provider_schema.models = [
model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type
]
providers.append(simple_provider_schema)
return providers
def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel:
"""
Get model type instance by provider name and model type
:param provider: provider name
:param model_type: model type
:return: model type instance
Get model type instance by provider name and model type.
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
init_params = {
"tenant_id": self.tenant_id,
"plugin_id": plugin_id,
"provider_name": provider_name,
"plugin_model_provider": self.get_plugin_model_provider(provider),
}
provider_schema = self.get_model_provider(provider)
if model_type == ModelType.LLM:
return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK:
return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION:
return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS:
return TTSModel.model_validate(init_params)
return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.RERANK:
return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.MODERATION:
return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.TTS:
return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
raise ValueError(f"Unsupported model type: {model_type}")
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""
Get provider icon
:param provider: provider name
:param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return: provider icon
Get provider icon.
"""
# get the provider schema
provider_schema = self.get_provider_schema(provider)
provider_entity = self.get_model_provider(provider)
return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang)
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
def _resolve_provider(self, provider: str) -> ProviderEntity | None:
return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None)
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_small.zh_Hans
else:
file_name = provider_schema.icon_small.en_US
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_small_dark.zh_Hans
else:
file_name = provider_schema.icon_small_dark.en_US
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}
extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
# get icon bytes from plugin asset manager
from core.plugin.impl.asset import PluginAssetManager
plugin_asset_manager = PluginAssetManager()
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]:
"""
Get plugin id and provider name from provider name
:param provider: provider name
:return: plugin id and provider name
"""
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name
@staticmethod
def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool:
return provider in (provider_entity.provider, provider_entity.provider_name)

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
from collections.abc import Generator, Iterable, Sequence
from typing import IO, Any, Protocol, Union, runtime_checkable
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
@runtime_checkable
class ModelRuntime(Protocol):
"""Port for provider discovery, schema lookup, and model execution.
`provider` is the model runtime's canonical provider identifier. Adapters may
derive transport-specific details from it, but those details stay outside
this boundary.
"""
def fetch_model_providers(self) -> Sequence[ProviderEntity]: ...
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ...
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ...
def validate_model_credentials(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> None: ...
def get_model_schema(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> AIModelEntity | None: ...
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ...
def get_llm_num_tokens(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
) -> int: ...
def invoke_text_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
input_type: EmbeddingInputType,
) -> EmbeddingResult: ...
def invoke_multimodal_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
documents: list[dict[str, Any]],
input_type: EmbeddingInputType,
) -> EmbeddingResult: ...
def get_text_embedding_num_tokens(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]: ...
def invoke_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult: ...
def invoke_multimodal_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult: ...
def invoke_tts(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Iterable[bytes]: ...
def get_tts_model_voices(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
language: str | None,
) -> Any: ...
def invoke_speech_to_text(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
file: IO[bytes],
) -> str: ...
def invoke_moderation(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
text: str,
) -> bool: ...