mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 07:00:37 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -93,10 +93,14 @@ class ModelCredentialSchema(BaseModel):
|
||||
|
||||
class SimpleProviderEntity(BaseModel):
|
||||
"""
|
||||
Simple model class for provider.
|
||||
Simplified provider schema exposed to callers.
|
||||
|
||||
`provider` is the canonical runtime identifier. `provider_name` is an optional
|
||||
compatibility alias for short-name lookups and is empty when no alias exists.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
provider_name: str = ""
|
||||
label: I18nObject
|
||||
icon_small: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
@@ -115,10 +119,15 @@ class ProviderHelpEntity(BaseModel):
|
||||
|
||||
class ProviderEntity(BaseModel):
|
||||
"""
|
||||
Model class for provider.
|
||||
Runtime-native provider schema.
|
||||
|
||||
`provider` is the canonical runtime identifier. `provider_name` is a
|
||||
compatibility alias for callers that still resolve providers by short name and
|
||||
is empty when no alias exists.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
provider_name: str = ""
|
||||
label: I18nObject
|
||||
description: I18nObject | None = None
|
||||
icon_small: I18nObject | None = None
|
||||
@@ -153,6 +162,7 @@ class ProviderEntity(BaseModel):
|
||||
"""
|
||||
return SimpleProviderEntity(
|
||||
provider=self.provider,
|
||||
provider_name=self.provider_name,
|
||||
label=self.label,
|
||||
icon_small=self.icon_small,
|
||||
supported_model_types=self.supported_model_types,
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MultimodalRerankInput(TypedDict):
|
||||
content: str
|
||||
content_type: str
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
"""
|
||||
Model class for rerank document.
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelUsage
|
||||
|
||||
|
||||
class EmbeddingInputType(StrEnum):
|
||||
"""Embedding request input variants understood by the model runtime."""
|
||||
|
||||
DOCUMENT = auto()
|
||||
QUERY = auto()
|
||||
|
||||
|
||||
class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
Model class for embedding usage.
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
import decimal
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from dify_graph.model_runtime.entities.model_entities import (
|
||||
@@ -17,6 +10,7 @@ from dify_graph.model_runtime.entities.model_entities import (
|
||||
PriceInfo,
|
||||
PriceType,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from dify_graph.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
@@ -25,45 +19,61 @@ from dify_graph.model_runtime.errors.invoke import (
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from dify_graph.model_runtime.runtime import ModelRuntime
|
||||
|
||||
|
||||
class AIModel(BaseModel):
|
||||
class AIModel:
|
||||
"""
|
||||
Base class for all models.
|
||||
Runtime-facing base class for all model providers.
|
||||
|
||||
This stays a regular Python class because instances hold live collaborators
|
||||
such as the provider schema and runtime adapter rather than user input that
|
||||
benefits from Pydantic validation. Subclasses must pin ``model_type`` via a
|
||||
class attribute; the base class is not meant to be instantiated directly.
|
||||
"""
|
||||
|
||||
tenant_id: str = Field(description="Tenant ID")
|
||||
model_type: ModelType = Field(description="Model type")
|
||||
plugin_id: str = Field(description="Plugin ID")
|
||||
provider_name: str = Field(description="Provider")
|
||||
plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider")
|
||||
started_at: float = Field(description="Invoke start time", default=0)
|
||||
model_type: ModelType
|
||||
provider_schema: ProviderEntity
|
||||
model_runtime: ModelRuntime
|
||||
started_at: float
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
def __init__(
|
||||
self,
|
||||
provider_schema: ProviderEntity,
|
||||
model_runtime: ModelRuntime,
|
||||
*,
|
||||
started_at: float = 0,
|
||||
) -> None:
|
||||
if getattr(type(self), "model_type", None) is None:
|
||||
raise TypeError("AIModel subclasses must define model_type as a class attribute")
|
||||
|
||||
self.model_type = type(self).model_type
|
||||
self.provider_schema = provider_schema
|
||||
self.model_runtime = model_runtime
|
||||
self.started_at = started_at
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return self.provider_schema.provider
|
||||
|
||||
@property
|
||||
def provider_display_name(self) -> str:
|
||||
return self.provider_schema.label.en_US
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
Map model invoke error to unified error.
|
||||
|
||||
:return: Invoke error mapping
|
||||
The key is the error type thrown to the caller, and the value contains
|
||||
runtime-facing exception types that should be normalized to it.
|
||||
"""
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
|
||||
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [InvokeBadRequestError],
|
||||
PluginDaemonInnerError: [PluginDaemonInnerError],
|
||||
ValueError: [ValueError],
|
||||
}
|
||||
|
||||
@@ -79,15 +89,18 @@ class AIModel(BaseModel):
|
||||
if invoke_error == InvokeAuthorizationError:
|
||||
return InvokeAuthorizationError(
|
||||
description=(
|
||||
f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
|
||||
f"[{self.provider_display_name}] Incorrect model credentials provided, "
|
||||
"please check and try again."
|
||||
)
|
||||
)
|
||||
elif isinstance(invoke_error, InvokeError):
|
||||
return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
|
||||
return InvokeError(
|
||||
description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}"
|
||||
)
|
||||
else:
|
||||
return error
|
||||
|
||||
return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
|
||||
return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}")
|
||||
|
||||
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
|
||||
"""
|
||||
@@ -144,65 +157,13 @@ class AIModel(BaseModel):
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
||||
|
||||
cached_schema_json = None
|
||||
try:
|
||||
cached_schema_json = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to read plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
if cached_schema_json:
|
||||
try:
|
||||
return AIModelEntity.model_validate_json(cached_schema_json)
|
||||
except ValidationError:
|
||||
logger.warning(
|
||||
"Failed to validate cached plugin model schema for model %s",
|
||||
model,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
redis_client.delete(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to delete invalid plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
schema = plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
return self.model_runtime.get_model_schema(
|
||||
provider=self.provider,
|
||||
model_type=self.model_type,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
)
|
||||
|
||||
if schema:
|
||||
try:
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to write plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema from credentials
|
||||
|
||||
@@ -4,9 +4,6 @@ import uuid
|
||||
from collections.abc import Callable, Generator, Iterator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.model_runtime.callbacks.base_callback import Callback
|
||||
from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||
@@ -140,11 +137,9 @@ def _build_llm_result_from_chunks(
|
||||
)
|
||||
|
||||
|
||||
def _invoke_llm_via_plugin(
|
||||
def _invoke_llm_via_runtime(
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
llm_model: "LargeLanguageModel",
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
@@ -154,25 +149,19 @@ def _invoke_llm_via_plugin(
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_llm(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
return llm_model.model_runtime.invoke_llm(
|
||||
provider=provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=tools,
|
||||
stop=list(stop) if stop else None,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_non_stream_plugin_result(
|
||||
def _normalize_non_stream_runtime_result(
|
||||
model: str,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
result: Union[LLMResult, Iterator[LLMResultChunk]],
|
||||
@@ -208,9 +197,6 @@ class LargeLanguageModel(AIModel):
|
||||
|
||||
model_type: ModelType = ModelType.LLM
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
@@ -220,7 +206,6 @@ class LargeLanguageModel(AIModel):
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
"""
|
||||
@@ -233,7 +218,6 @@ class LargeLanguageModel(AIModel):
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
@@ -245,7 +229,7 @@ class LargeLanguageModel(AIModel):
|
||||
|
||||
callbacks = callbacks or []
|
||||
|
||||
if dify_config.DEBUG:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
callbacks.append(LoggingCallback())
|
||||
|
||||
# trigger before invoke callbacks
|
||||
@@ -257,18 +241,15 @@ class LargeLanguageModel(AIModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
result = _invoke_llm_via_plugin(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
result = _invoke_llm_via_runtime(
|
||||
llm_model=self,
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
@@ -279,7 +260,7 @@ class LargeLanguageModel(AIModel):
|
||||
)
|
||||
|
||||
if not stream:
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
result = _normalize_non_stream_runtime_result(
|
||||
model=model, prompt_messages=prompt_messages, result=result
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -292,7 +273,6 @@ class LargeLanguageModel(AIModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
@@ -309,7 +289,6 @@ class LargeLanguageModel(AIModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
elif isinstance(result, LLMResult):
|
||||
@@ -322,7 +301,6 @@ class LargeLanguageModel(AIModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# Following https://github.com/langgenius/dify/issues/17799,
|
||||
@@ -435,22 +413,14 @@ class LargeLanguageModel(AIModel):
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_llm_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
)
|
||||
return 0
|
||||
return self.model_runtime.get_llm_num_tokens(
|
||||
provider=self.provider,
|
||||
model_type=self.model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
def calc_response_usage(
|
||||
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import time
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
@@ -13,30 +11,20 @@ class ModerationModel(AIModel):
|
||||
|
||||
model_type: ModelType = ModelType.MODERATION
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
|
||||
def invoke(self, model: str, credentials: dict, text: str) -> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.invoke_moderation(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
|
||||
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ class RerankModel(AIModel):
|
||||
docs: list[str],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
@@ -29,18 +28,11 @@ class RerankModel(AIModel):
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.invoke_rerank(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
@@ -55,11 +47,10 @@ class RerankModel(AIModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke multimodal rerank model
|
||||
@@ -69,18 +60,11 @@ class RerankModel(AIModel):
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_multimodal_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.invoke_multimodal_rerank(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import IO
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
@@ -13,28 +11,18 @@ class Speech2TextModel(AIModel):
|
||||
|
||||
model_type: ModelType = ModelType.SPEECH2TEXT
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
|
||||
"""
|
||||
Invoke speech to text model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_speech_to_text(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.invoke_speech_to_text(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
@@ -13,16 +10,12 @@ class TextEmbeddingModel(AIModel):
|
||||
|
||||
model_type: ModelType = ModelType.TEXT_EMBEDDING
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str] | None = None,
|
||||
multimodel_documents: list[dict] | None = None,
|
||||
user: str | None = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> EmbeddingResult:
|
||||
"""
|
||||
@@ -32,31 +25,21 @@ class TextEmbeddingModel(AIModel):
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param files: files to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
try:
|
||||
plugin_model_manager = PluginModelClient()
|
||||
if texts:
|
||||
return plugin_model_manager.invoke_text_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.invoke_text_embedding(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
)
|
||||
if multimodel_documents:
|
||||
return plugin_model_manager.invoke_multimodal_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.invoke_multimodal_embedding(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
documents=multimodel_documents,
|
||||
@@ -75,14 +58,8 @@ class TextEmbeddingModel(AIModel):
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_text_embedding_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.get_text_embedding_num_tokens(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
@@ -16,38 +14,25 @@ class TTSModel(AIModel):
|
||||
|
||||
model_type: ModelType = ModelType.TTS
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
model: str,
|
||||
tenant_id: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: str | None = None,
|
||||
) -> Iterable[bytes]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param voice: model timbre
|
||||
:param content_text: text content to be translated
|
||||
:param user: unique user id
|
||||
:return: translated audio file
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_tts(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.invoke_tts(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
@@ -65,14 +50,8 @@ class TTSModel(AIModel):
|
||||
:param credentials: The credentials required to access the TTS model.
|
||||
:return: A list of voices supported by the TTS model.
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_tts_model_voices(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
return self.model_runtime.get_tts_model_voices(
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
language=language,
|
||||
|
||||
@@ -1,16 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from threading import Lock
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@@ -20,120 +11,64 @@ from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankM
|
||||
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from dify_graph.model_runtime.runtime import ModelRuntime
|
||||
from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||
from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import (
|
||||
ProviderCredentialSchemaValidator,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
def __init__(self, tenant_id: str):
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
"""Factory for provider schemas and model-type instances backed by a runtime adapter."""
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_model_manager = PluginModelClient()
|
||||
def __init__(self, model_runtime: ModelRuntime):
|
||||
if model_runtime is None:
|
||||
raise ValueError("model_runtime is required.")
|
||||
self.model_runtime = model_runtime
|
||||
|
||||
def get_providers(self) -> Sequence[ProviderEntity]:
|
||||
"""
|
||||
Get all providers
|
||||
:return: list of providers
|
||||
Get all providers.
|
||||
"""
|
||||
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
|
||||
# The plugin server should return providers in the desired order
|
||||
plugin_providers = self.get_plugin_model_providers()
|
||||
return [provider.declaration for provider in plugin_providers]
|
||||
return list(self.get_model_providers())
|
||||
|
||||
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
|
||||
def get_model_providers(self) -> Sequence[ProviderEntity]:
|
||||
"""
|
||||
Get all plugin model providers
|
||||
:return: list of plugin model providers
|
||||
Get all model providers exposed by the runtime adapter.
|
||||
"""
|
||||
# check if context is set
|
||||
try:
|
||||
contexts.plugin_model_providers.get()
|
||||
except LookupError:
|
||||
contexts.plugin_model_providers.set(None)
|
||||
contexts.plugin_model_providers_lock.set(Lock())
|
||||
|
||||
with contexts.plugin_model_providers_lock.get():
|
||||
plugin_model_providers = contexts.plugin_model_providers.get()
|
||||
if plugin_model_providers is not None:
|
||||
return plugin_model_providers
|
||||
|
||||
plugin_model_providers = []
|
||||
contexts.plugin_model_providers.set(plugin_model_providers)
|
||||
|
||||
# Fetch plugin model providers
|
||||
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
|
||||
|
||||
for provider in plugin_providers:
|
||||
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
|
||||
plugin_model_providers.append(provider)
|
||||
|
||||
return plugin_model_providers
|
||||
return self.model_runtime.fetch_model_providers()
|
||||
|
||||
def get_provider_schema(self, provider: str) -> ProviderEntity:
|
||||
"""
|
||||
Get provider schema
|
||||
:param provider: provider name
|
||||
:return: provider schema
|
||||
Get provider schema.
|
||||
"""
|
||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
return plugin_model_provider_entity.declaration
|
||||
return self.get_model_provider(provider=provider)
|
||||
|
||||
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
|
||||
def get_model_provider(self, provider: str) -> ProviderEntity:
|
||||
"""
|
||||
Get plugin model provider
|
||||
:param provider: provider name
|
||||
:return: provider schema
|
||||
Get provider schema.
|
||||
"""
|
||||
if "/" not in provider:
|
||||
provider = str(ModelProviderID(provider))
|
||||
|
||||
# fetch plugin model providers
|
||||
plugin_model_provider_entities = self.get_plugin_model_providers()
|
||||
|
||||
# get the provider
|
||||
plugin_model_provider_entity = next(
|
||||
(p for p in plugin_model_provider_entities if p.declaration.provider == provider),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin_model_provider_entity:
|
||||
provider_entity = self._resolve_provider(provider)
|
||||
if provider_entity is None:
|
||||
raise ValueError(f"Invalid provider: {provider}")
|
||||
|
||||
return plugin_model_provider_entity
|
||||
return provider_entity
|
||||
|
||||
def provider_credentials_validate(self, *, provider: str, credentials: dict):
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
:param provider: provider name
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
:return:
|
||||
Validate provider credentials.
|
||||
"""
|
||||
# fetch plugin model provider
|
||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
provider_entity = self.get_model_provider(provider=provider)
|
||||
|
||||
# get provider_credential_schema and validate credentials according to the rules
|
||||
provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema
|
||||
provider_credential_schema = provider_entity.provider_credential_schema
|
||||
if not provider_credential_schema:
|
||||
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
|
||||
|
||||
# validate provider credential schema
|
||||
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
|
||||
filtered_credentials = validator.validate_and_filter(credentials)
|
||||
|
||||
# validate the credentials, raise exception if validation failed
|
||||
self.plugin_model_manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=plugin_model_provider_entity.plugin_id,
|
||||
provider=plugin_model_provider_entity.provider,
|
||||
self.model_runtime.validate_provider_credentials(
|
||||
provider=provider_entity.provider,
|
||||
credentials=filtered_credentials,
|
||||
)
|
||||
|
||||
@@ -141,33 +76,20 @@ class ModelProviderFactory:
|
||||
|
||||
def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict):
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials, credentials form defined in `model_credential_schema`.
|
||||
:return:
|
||||
Validate model credentials.
|
||||
"""
|
||||
# fetch plugin model provider
|
||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
provider_entity = self.get_model_provider(provider=provider)
|
||||
|
||||
# get model_credential_schema and validate credentials according to the rules
|
||||
model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema
|
||||
model_credential_schema = provider_entity.model_credential_schema
|
||||
if not model_credential_schema:
|
||||
raise ValueError(f"Provider {provider} does not have model_credential_schema")
|
||||
|
||||
# validate model credential schema
|
||||
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
|
||||
filtered_credentials = validator.validate_and_filter(credentials)
|
||||
|
||||
# call validate_credentials method of model type to validate credentials, raise exception if validation failed
|
||||
self.plugin_model_manager.validate_model_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=plugin_model_provider_entity.plugin_id,
|
||||
provider=plugin_model_provider_entity.provider,
|
||||
model_type=model_type.value,
|
||||
self.model_runtime.validate_model_credentials(
|
||||
provider=provider_entity.provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=filtered_credentials,
|
||||
)
|
||||
@@ -178,65 +100,16 @@ class ModelProviderFactory:
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
Get model schema.
|
||||
"""
|
||||
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
|
||||
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
||||
|
||||
cached_schema_json = None
|
||||
try:
|
||||
cached_schema_json = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to read plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
if cached_schema_json:
|
||||
try:
|
||||
return AIModelEntity.model_validate_json(cached_schema_json)
|
||||
except ValidationError:
|
||||
logger.warning(
|
||||
"Failed to validate cached plugin model schema for model %s",
|
||||
model,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
redis_client.delete(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to delete invalid plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
schema = self.plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
provider_entity = self.get_model_provider(provider)
|
||||
return self.model_runtime.get_model_schema(
|
||||
provider=provider_entity.provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
)
|
||||
|
||||
if schema:
|
||||
try:
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to write plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return schema
|
||||
|
||||
def get_models(
|
||||
self,
|
||||
*,
|
||||
@@ -245,143 +118,56 @@ class ModelProviderFactory:
|
||||
provider_configs: list[ProviderConfig] | None = None,
|
||||
) -> list[SimpleProviderEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param provider_configs: list of provider configs
|
||||
:return: list of models
|
||||
Get all models for given model type.
|
||||
"""
|
||||
provider_configs = provider_configs or []
|
||||
|
||||
# scan all providers
|
||||
plugin_model_provider_entities = self.get_plugin_model_providers()
|
||||
|
||||
# traverse all model_provider_extensions
|
||||
providers = []
|
||||
for plugin_model_provider_entity in plugin_model_provider_entities:
|
||||
# filter by provider if provider is present
|
||||
if provider and plugin_model_provider_entity.declaration.provider != provider:
|
||||
for provider_entity in self.get_model_providers():
|
||||
if provider and not self._matches_provider(provider_entity, provider):
|
||||
continue
|
||||
|
||||
# get provider schema
|
||||
provider_schema = plugin_model_provider_entity.declaration
|
||||
|
||||
model_types = provider_schema.supported_model_types
|
||||
if model_type:
|
||||
if model_type not in model_types:
|
||||
continue
|
||||
|
||||
model_types = [model_type]
|
||||
|
||||
all_model_type_models = []
|
||||
for model_schema in provider_schema.models:
|
||||
if model_schema.model_type != model_type:
|
||||
continue
|
||||
|
||||
all_model_type_models.append(model_schema)
|
||||
|
||||
simple_provider_schema = provider_schema.to_simple_provider()
|
||||
if model_type:
|
||||
simple_provider_schema.models = all_model_type_models
|
||||
if model_type and model_type not in provider_entity.supported_model_types:
|
||||
continue
|
||||
|
||||
simple_provider_schema = provider_entity.to_simple_provider()
|
||||
if model_type is not None:
|
||||
simple_provider_schema.models = [
|
||||
model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type
|
||||
]
|
||||
providers.append(simple_provider_schema)
|
||||
|
||||
return providers
|
||||
|
||||
def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel:
|
||||
"""
|
||||
Get model type instance by provider name and model type
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:return: model type instance
|
||||
Get model type instance by provider name and model type.
|
||||
"""
|
||||
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
|
||||
init_params = {
|
||||
"tenant_id": self.tenant_id,
|
||||
"plugin_id": plugin_id,
|
||||
"provider_name": provider_name,
|
||||
"plugin_model_provider": self.get_plugin_model_provider(provider),
|
||||
}
|
||||
provider_schema = self.get_model_provider(provider)
|
||||
|
||||
if model_type == ModelType.LLM:
|
||||
return LargeLanguageModel.model_validate(init_params)
|
||||
elif model_type == ModelType.TEXT_EMBEDDING:
|
||||
return TextEmbeddingModel.model_validate(init_params)
|
||||
elif model_type == ModelType.RERANK:
|
||||
return RerankModel.model_validate(init_params)
|
||||
elif model_type == ModelType.SPEECH2TEXT:
|
||||
return Speech2TextModel.model_validate(init_params)
|
||||
elif model_type == ModelType.MODERATION:
|
||||
return ModerationModel.model_validate(init_params)
|
||||
elif model_type == ModelType.TTS:
|
||||
return TTSModel.model_validate(init_params)
|
||||
return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
|
||||
if model_type == ModelType.TEXT_EMBEDDING:
|
||||
return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
|
||||
if model_type == ModelType.RERANK:
|
||||
return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
|
||||
if model_type == ModelType.SPEECH2TEXT:
|
||||
return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
|
||||
if model_type == ModelType.MODERATION:
|
||||
return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
|
||||
if model_type == ModelType.TTS:
|
||||
return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
|
||||
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
"""
|
||||
Get provider icon
|
||||
:param provider: provider name
|
||||
:param icon_type: icon type (icon_small or icon_small_dark)
|
||||
:param lang: language (zh_Hans or en_US)
|
||||
:return: provider icon
|
||||
Get provider icon.
|
||||
"""
|
||||
# get the provider schema
|
||||
provider_schema = self.get_provider_schema(provider)
|
||||
provider_entity = self.get_model_provider(provider)
|
||||
return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang)
|
||||
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
def _resolve_provider(self, provider: str) -> ProviderEntity | None:
|
||||
return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None)
|
||||
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_small.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_small.en_US
|
||||
elif icon_type.lower() == "icon_small_dark":
|
||||
if not provider_schema.icon_small_dark:
|
||||
raise ValueError(f"Provider {provider} does not have small dark icon.")
|
||||
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_small_dark.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_small_dark.en_US
|
||||
else:
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
if not file_name:
|
||||
raise ValueError(f"Provider {provider} does not have icon.")
|
||||
|
||||
image_mime_types = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"bmp": "image/bmp",
|
||||
"tiff": "image/tiff",
|
||||
"tif": "image/tiff",
|
||||
"webp": "image/webp",
|
||||
"svg": "image/svg+xml",
|
||||
"ico": "image/vnd.microsoft.icon",
|
||||
"heif": "image/heif",
|
||||
"heic": "image/heic",
|
||||
}
|
||||
|
||||
extension = file_name.split(".")[-1]
|
||||
mime_type = image_mime_types.get(extension, "image/png")
|
||||
|
||||
# get icon bytes from plugin asset manager
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
|
||||
plugin_asset_manager = PluginAssetManager()
|
||||
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
|
||||
|
||||
def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]:
|
||||
"""
|
||||
Get plugin id and provider name from provider name
|
||||
:param provider: provider name
|
||||
:return: plugin id and provider name
|
||||
"""
|
||||
|
||||
provider_id = ModelProviderID(provider)
|
||||
return provider_id.plugin_id, provider_id.provider_name
|
||||
@staticmethod
|
||||
def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool:
|
||||
return provider in (provider_entity.provider, provider_entity.provider_name)
|
||||
|
||||
159
api/dify_graph/model_runtime/runtime.py
Normal file
159
api/dify_graph/model_runtime/runtime.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from typing import IO, Any, Protocol, Union, runtime_checkable
|
||||
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ModelRuntime(Protocol):
|
||||
"""Port for provider discovery, schema lookup, and model execution.
|
||||
|
||||
`provider` is the model runtime's canonical provider identifier. Adapters may
|
||||
derive transport-specific details from it, but those details stay outside
|
||||
this boundary.
|
||||
"""
|
||||
|
||||
def fetch_model_providers(self) -> Sequence[ProviderEntity]: ...
|
||||
|
||||
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ...
|
||||
|
||||
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ...
|
||||
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> None: ...
|
||||
|
||||
def get_model_schema(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> AIModelEntity | None: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ...
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
) -> int: ...
|
||||
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
input_type: EmbeddingInputType,
|
||||
) -> EmbeddingResult: ...
|
||||
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
documents: list[dict[str, Any]],
|
||||
input_type: EmbeddingInputType,
|
||||
) -> EmbeddingResult: ...
|
||||
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
) -> list[int]: ...
|
||||
|
||||
def invoke_rerank(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: float | None,
|
||||
top_n: int | None,
|
||||
) -> RerankResult: ...
|
||||
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None,
|
||||
top_n: int | None,
|
||||
) -> RerankResult: ...
|
||||
|
||||
def invoke_tts(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
content_text: str,
|
||||
voice: str,
|
||||
) -> Iterable[bytes]: ...
|
||||
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
language: str | None,
|
||||
) -> Any: ...
|
||||
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
file: IO[bytes],
|
||||
) -> str: ...
|
||||
|
||||
def invoke_moderation(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
text: str,
|
||||
) -> bool: ...
|
||||
Reference in New Issue
Block a user