feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tools=[],
stop=app_generate_entity.model_conf.stop,
stream=True,
user=self.user_id,
callbacks=[],
)

View File

@@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)

View File

@@ -4,7 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -21,7 +21,7 @@ class ModelConfigConverter:
"""
model_config = app_config.model
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
)

View File

@@ -2,9 +2,8 @@ from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID
@@ -55,7 +54,7 @@ class ModelConfigManager:
raise ValueError("model must be of object type")
# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"]:
@@ -71,7 +70,7 @@ class ModelConfigManager:
if "name" not in config["model"]:
raise ValueError("model.name is required")
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"], model_type=ModelType.LLM
)

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.file import File, FileUploadConfig
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
@@ -15,6 +14,9 @@ if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
DIFY_RUN_CONTEXT_KEY = "_dify"
class UserFrom(StrEnum):
ACCOUNT = "account"
END_USER = "end-user"

View File

@@ -2,23 +2,34 @@ from __future__ import annotations
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
from core.app.llm.protocols import CredentialsProvider, ModelFactory
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.nodes.llm.entities import ModelConfig
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
class DifyCredentialsProvider:
tenant_id: str
provider_manager: ProviderManager
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
self.tenant_id = tenant_id
self.provider_manager = provider_manager or ProviderManager()
def __init__(
self,
*,
run_context: DifyRunContext,
provider_manager: ProviderManager | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
if provider_manager is None:
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
@@ -42,9 +53,21 @@ class DifyModelFactory:
tenant_id: str
model_manager: ModelManager
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
self.tenant_id = tenant_id
self.model_manager = model_manager or ModelManager()
def __init__(
self,
*,
run_context: DifyRunContext,
model_manager: ModelManager | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
if model_manager is None:
model_manager = ModelManager(
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
)
self.model_manager = model_manager
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
return self.model_manager.get_model_instance(
@@ -55,11 +78,35 @@ class DifyModelFactory:
)
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
return (
DifyCredentialsProvider(tenant_id=tenant_id),
DifyModelFactory(tenant_id=tenant_id),
def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, ModelFactory]:
"""Create LLM access adapters that share the same tenant-bound manager graph."""
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
DifyModelFactory(run_context=run_context, model_manager=model_manager),
)
def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Split node-level completion params into provider parameters and stop sequences.
Workflow LLM-compatible nodes still consume runtime invocation settings from
``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the
``ModelInstance`` view and the returned config entity aligned here so callers
do not need to duplicate normalization logic.
"""
normalized_parameters = dict(completion_params)
stop = normalized_parameters.pop("stop", [])
if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop):
stop = []
return normalized_parameters, stop
def fetch_model_config(
@@ -80,22 +127,18 @@ def fetch_model_config(
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
if model_schema is None:
raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.")
parameters, stop = _normalize_completion_params(node_data_model.completion_params)
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.parameters = parameters
model_instance.stop = tuple(stop)
return model_instance, ModelConfigWithCredentialsEntity(
@@ -103,8 +146,8 @@ def fetch_model_config(
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=completion_params,
parameters=parameters,
stop=stop,
provider_model_bundle=provider_model_bundle,
)

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from typing import Any, Protocol
from core.model_manager import ModelInstance
class CredentialsProvider(Protocol):
"""Workflow-layer port for loading runtime credentials for a provider/model pair."""
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
"""Return credentials for the target provider/model or raise a domain error."""
...
class ModelFactory(Protocol):
"""Workflow-layer port for creating mutable ModelInstance objects."""
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for workflow-side hydration."""
...

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
@@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer):
return
try:
dify_ctx = node.require_dify_context()
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
deduct_llm_quota(
tenant_id=dify_ctx.tenant_id,
model_instance=model_instance,

View File

@@ -25,12 +25,10 @@ class AudioTrunk:
self.status = status
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)
return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice)
def _process_future(
@@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher:
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts")
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
@@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher:
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
_invoice_tts, self.msg_text, self.model_instance, self.voice
)
future_queue.put(futures_result)
break
@@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher:
if len(sentence_arr) >= min(self.max_sentence, 7):
self.max_sentence += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice)
future_queue.put(futures_result)
if isinstance(text_tmp, str):
self.msg_text = text_tmp

View File

@@ -1,10 +1,5 @@
from enum import StrEnum, auto
"""Compatibility wrapper for the runtime embedding input enum."""
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType
class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""
DOCUMENT = auto()
QUERY = auto()
__all__ = ["EmbeddingInputType"]

View File

@@ -19,6 +19,7 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from dify_graph.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@@ -27,7 +28,6 @@ from dify_graph.model_runtime.entities.provider_entities import (
ProviderEntity,
)
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@@ -343,7 +343,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
@@ -902,7 +902,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1388,7 +1388,7 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
@@ -1397,7 +1397,7 @@ class ProviderConfiguration(BaseModel):
"""
Get model schema
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1499,7 +1499,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types: list[ModelType] = []

View File

@@ -4,10 +4,10 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType
@@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
text_chunk = secrets.choice(text_chunks)
try:
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(

View File

@@ -50,7 +50,10 @@ logger = logging.getLogger(__name__)
class IndexingRunner:
def __init__(self):
self.storage = storage
self.model_manager = ModelManager()
@staticmethod
def _get_model_manager(tenant_id: str) -> ModelManager:
return ModelManager.for_tenant(tenant_id=tenant_id)
def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
"""Handle indexing errors by updating document status."""
@@ -291,20 +294,20 @@ class IndexingRunner:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == "high_quality":
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
@@ -574,7 +577,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
@@ -766,14 +769,14 @@ class IndexingRunner:
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)

View File

@@ -62,7 +62,7 @@ class LLMGenerator:
prompt += query + "\n"
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -120,7 +120,7 @@ class LLMGenerator:
prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions})
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -172,7 +172,7 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt_generate)]
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
@@ -219,7 +219,7 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)]
# get model instance
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -306,7 +306,7 @@ class LLMGenerator:
remove_template_variables=False,
)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -337,7 +337,7 @@ class LLMGenerator:
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -362,7 +362,7 @@ class LLMGenerator:
@classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -536,7 +536,7 @@ class LLMGenerator:
injected_instruction = injected_instruction.replace(CURRENT, current or "null")
if ERROR_MESSAGE in injected_instruction:
injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
model_instance = ModelManager().get_model_instance(
model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,

View File

@@ -55,7 +55,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True],
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
@@ -70,7 +69,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False],
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
@@ -85,7 +83,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
@@ -99,7 +96,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
"""
@@ -113,7 +109,6 @@ def invoke_llm_with_structured_output(
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -143,7 +138,6 @@ def invoke_llm_with_structured_output(
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)

View File

@@ -7,6 +7,7 @@ from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import LLMResult
@@ -30,7 +31,7 @@ logger = logging.getLogger(__name__)
class ModelInstance:
"""
Model instance class
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
@@ -110,7 +111,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator: ...
@@ -122,7 +122,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResult: ...
@@ -134,7 +133,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]: ...
@@ -145,7 +143,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]:
"""
@@ -156,7 +153,6 @@ class ModelInstance:
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -173,7 +169,6 @@ class ModelInstance:
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
),
)
@@ -202,13 +197,12 @@ class ModelInstance:
)
def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> EmbeddingResult:
"""
Invoke large language model
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -221,7 +215,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
),
)
@@ -229,14 +222,12 @@ class ModelInstance:
def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke large language model
:param multimodel_documents: multimodel documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -249,7 +240,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)
@@ -279,7 +269,6 @@ class ModelInstance:
docs: list[str],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -288,7 +277,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -303,7 +291,6 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
@@ -313,7 +300,6 @@ class ModelInstance:
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -322,7 +308,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -337,16 +322,14 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
def invoke_moderation(self, text: str) -> bool:
"""
Invoke moderation model
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
@@ -358,16 +341,14 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
text=text,
user=user,
),
)
def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str:
def invoke_speech2text(self, file: IO[bytes]) -> str:
"""
Invoke large language model
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
@@ -379,18 +360,15 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
file=file,
user=user,
),
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]:
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
"""
Invoke large language tts model
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param voice: model timbre
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
@@ -402,8 +380,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
),
)
@@ -477,10 +453,20 @@ class ModelInstance:
class ModelManager:
def __init__(self):
self._provider_manager = ProviderManager()
def __init__(self, provider_manager: ProviderManager):
self._provider_manager = provider_manager
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id))
def get_model_instance(
self,
tenant_id: str,
provider: str,
model_type: ModelType,
model: str,
) -> ModelInstance:
"""
Get model instance
:param tenant_id: tenant id
@@ -496,7 +482,8 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@@ -50,7 +50,7 @@ class OpenAIModeration(Moderation):
def _is_violated(self, inputs: dict):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
)

View File

@@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Tenant
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
@staticmethod
def _get_bound_model_instance(
*,
tenant_id: str,
user_id: str | None,
provider: str,
model_type: ModelType,
model: str,
):
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=provider,
model_type=model_type,
model=model,
)
@classmethod
def invoke_llm(
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
@@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
)
if isinstance(response, Generator):
@@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm with structured output
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
model_parameters=payload.completion_params,
)
@@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke text embedding
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_text_embedding(
texts=payload.texts,
user=user_id,
)
response = model_instance.invoke_text_embedding(texts=payload.texts)
return response
@@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke rerank
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
docs=payload.docs,
score_threshold=payload.score_threshold,
top_n=payload.top_n,
user=user_id,
)
return response
@@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke tts
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_tts(
content_text=payload.content_text,
tenant_id=tenant.id,
voice=payload.voice,
user=user_id,
)
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
def handle() -> Generator[dict, None, None]:
for chunk in response:
@@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke speech2text
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
temp.flush()
temp.seek(0)
response = model_instance.invoke_speech2text(
file=temp,
user=user_id,
)
response = model_instance.invoke_speech2text(file=temp)
return {
"result": response,
@@ -252,18 +261,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke moderation
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_moderation(
text=payload.text,
user=user_id,
)
response = model_instance.invoke_moderation(text=payload.text)
return {
"result": response,

View File

@@ -1,6 +1,6 @@
import binascii
from collections.abc import Generator, Sequence
from typing import IO
from typing import IO, Any
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
@@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
class PluginModelClient(BasePluginClient):
@staticmethod
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {"data": data}
if user_id is not None:
payload["user_id"] = user_id
return payload
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
"""
Fetch model providers for the given tenant.
@@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient):
def get_model_schema(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/schema",
PluginModelSchemaEntity,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient):
return None
def validate_provider_credentials(
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
) -> bool:
"""
validate the credentials of the provider
@@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient):
def validate_model_credentials(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient):
def invoke_llm(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "llm",
"model": model,
@@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient):
"stop": stop,
"stream": stream,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient):
def get_llm_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
@@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient):
"prompt_messages": prompt_messages,
"tools": tools,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient):
def invoke_text_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient):
"texts": texts,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient):
"documents": documents,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient):
def get_text_embedding_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"texts": texts,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient):
def invoke_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient):
def invoke_tts(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
@@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient):
"content_text": content_text,
"voice": voice,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient):
def get_tts_model_voices(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
"credentials": credentials,
"language": language,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient):
def invoke_speech_to_text(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "speech2text",
"model": model,
"credentials": credentials,
"file": binascii.hexlify(file.read()).decode(),
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient):
def invoke_moderation(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "moderation",
"model": model,
"credentials": credentials,
"text": text,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,

View File

@@ -0,0 +1,490 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union
from pydantic import ValidationError
from redis import RedisError
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.runtime import ModelRuntime
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and a default user."""
tenant_id: str
user_id: str | None
client: PluginModelClient
_provider_entities: tuple[ProviderEntity, ...] | None
_provider_entities_lock: Lock
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
if client is None:
raise ValueError("client is required.")
self.tenant_id = tenant_id
self.user_id = user_id
self.client = client
self._provider_entities = None
self._provider_entities_lock = Lock()
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
if self._provider_entities is not None:
return self._provider_entities
with self._provider_entities_lock:
if self._provider_entities is None:
self._provider_entities = tuple(
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
)
return self._provider_entities
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_Hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_US
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}
extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
credentials=credentials,
)
def validate_model_credentials(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_model_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
def get_model_schema(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> AIModelEntity | None:
cache_key = self._get_schema_cache_key(
provider=provider,
model_type=model_type,
model=model,
credentials=credentials,
)
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
plugin_id, provider_name = self._split_provider(provider)
schema = self.client.get_model_schema(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stream=stream,
)
def get_llm_num_tokens(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
) -> int:
if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
return 0
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)
def invoke_text_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
def invoke_multimodal_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
documents: list[dict[str, Any]],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
documents=documents,
input_type=input_type,
)
def get_text_embedding_num_tokens(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
)
def invoke_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_multimodal_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_tts(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Iterable[bytes]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_tts(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def get_tts_model_voices(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
language: str | None,
) -> Any:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
language=language,
)
def invoke_speech_to_text(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
file=file,
)
def invoke_moderation(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
text: str,
) -> bool:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_moderation(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
text=text,
)
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.
Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""
if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""
return provider.provider
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = self._get_provider_short_name_alias(provider)
return declaration
def _get_provider_schema(self, provider: str) -> ProviderEntity:
providers = self.fetch_model_providers()
provider_entity = next((item for item in providers if item.provider == provider), None)
if provider_entity is None:
provider_entity = next((item for item in providers if provider == item.provider_name), None)
if provider_entity is None:
raise ValueError(f"Invalid provider: {provider}")
return provider_entity
def _get_schema_cache_key(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> str:
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
return cache_key + ":".join(
[hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
)
def _split_provider(self, provider: str) -> tuple[str, str]:
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
"""Create a plugin runtime with its client dependency fully composed."""
from core.plugin.impl.model_runtime import PluginModelRuntime
return PluginModelRuntime(
tenant_id=tenant_id,
user_id=user_id,
client=PluginModelClient(),
)
def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
"""Create a tenant-bound model provider factory for service flows."""
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
return ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
"""Create a tenant-bound provider manager for service flows."""
from core.provider_manager import ProviderManager
return ProviderManager(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
"""Create a tenant-bound model manager for service flows."""
from core.model_manager import ModelManager
return ModelManager(
provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id),
)

View File

@@ -1,50 +1,7 @@
from typing import Literal
from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from pydantic import BaseModel
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
class ChatModelMessage(BaseModel):
"""
Chat Message.
"""
text: str
role: PromptMessageRole
edition_type: Literal["basic", "jinja2"] | None = None
class CompletionModelPromptTemplate(BaseModel):
"""
Completion Model Prompt Template.
"""
text: str
edition_type: Literal["basic", "jinja2"] | None = None
class MemoryConfig(BaseModel):
"""
Memory Config.
"""
class RolePrefix(BaseModel):
"""
Role Prefix.
"""
user: str
assistant: str
class WindowConfig(BaseModel):
"""
Window Config.
"""
enabled: bool
size: int | None = None
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
__all__ = [
"ChatModelMessage",
"CompletionModelPromptTemplate",
"MemoryConfig",
]

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -53,15 +55,22 @@ from models.provider import (
from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService
if TYPE_CHECKING:
from dify_graph.model_runtime.runtime import ModelRuntime
class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
ProviderManager manages tenant-scoped model provider configuration.
The runtime adapter is injected by the composition layer so this class stays
focused on configuration assembly instead of constructing plugin runtimes.
"""
def __init__(self):
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@@ -127,7 +136,7 @@ class ProviderManager:
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()
# Get All preferred provider types of the workspace
@@ -321,7 +330,7 @@ class ProviderManager:
if not default_model:
return None
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
return DefaultModelEntity(

View File

@@ -106,9 +106,9 @@ class DataPostProcessor:
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

View File

@@ -328,7 +328,7 @@ class RetrievalService:
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
if dataset.is_multimodal:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model["reranking_provider_name"],

View File

@@ -303,7 +303,7 @@ class Vector:
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,

View File

@@ -72,7 +72,7 @@ class DatasetDocumentStore:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,

View File

@@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.rag.embedding.embedding_base import Embeddings
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -22,8 +22,20 @@ logger = logging.getLogger(__name__)
class CacheEmbedding(Embeddings):
def __init__(self, model_instance: ModelInstance, user: str | None = None):
self._model_instance = model_instance
self._user = user
self._model_instance = self._bind_model_instance(model_instance, user)
@staticmethod
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
if user is None:
return model_instance
tenant_id = model_instance.provider_model_bundle.configuration.tenant_id
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
@@ -65,7 +77,7 @@ class CacheEmbedding(Embeddings):
batch_texts = embedding_queue_texts[i : i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
)
for vector in embedding_result.embeddings:
@@ -147,7 +159,6 @@ class CacheEmbedding(Embeddings):
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
@@ -202,7 +213,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
texts=[text], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
@@ -245,7 +256,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]

View File

@@ -12,7 +12,7 @@ from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -410,7 +410,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# If default prompt doesn't have {language} placeholder, use it as-is
pass
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)

View File

@@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner):
def __init__(self, rerank_model_instance: ModelInstance):
self.rerank_model_instance = rerank_model_instance
def _get_invoke_model_instance(self, user: str | None) -> ModelInstance:
if user is None:
return self.rerank_model_instance
tenant_id = self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
tenant_id=tenant_id,
provider=self.rerank_model_instance.provider,
model_type=ModelType.RERANK,
model=self.rerank_model_instance.model_name,
)
def run(
self,
query: str,
@@ -34,7 +46,9 @@ class RerankModelRunner(BaseRerankRunner):
:param user: unique user id if needed
:return:
"""
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
@@ -102,8 +116,8 @@ class RerankModelRunner(BaseRerankRunner):
docs.append(document.page_content)
unique_documents.append(document)
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
rerank_result = self._get_invoke_model_instance(user).invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
@@ -180,8 +194,8 @@ class RerankModelRunner(BaseRerankRunner):
"content": file_query,
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
rerank_result = self._get_invoke_model_instance(user).invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

View File

@@ -163,7 +163,7 @@ class WeightRerankRunner(BaseRerankRunner):
"""
query_vector_scores = []
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id,

View File

@@ -160,7 +160,7 @@ class DatasetRetrieval:
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
@@ -387,7 +387,7 @@ class DatasetRetrieval:
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
)
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)
# fetch prompt messages
prompt_messages, stop = self._get_prompt_template(
@@ -1430,7 +1430,6 @@ class DatasetRetrieval:
model_parameters=model_config.parameters,
stop=stop,
stream=True,
user=user_id,
),
)
@@ -1533,7 +1532,7 @@ class DatasetRetrieval:
return filters
def _fetch_model_config(
self, tenant_id: str, model: ModelConfig
self, tenant_id: str, model: ModelConfig, user_id: str | None = None
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
@@ -1543,7 +1542,7 @@ class DatasetRetrieval:
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)

View File

@@ -3,13 +3,14 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelType
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -150,19 +151,24 @@ class ReactMultiDatasetRouter:
:param stop: stop
:return:
"""
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=ModelType.LLM,
model=model_instance.model_name,
)
invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
)
# handle invoke result
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
return text, usage

View File

@@ -201,8 +201,10 @@ class HumanInputFormRepositoryImpl:
self,
*,
tenant_id: str,
app_id: str | None = None,
):
self._tenant_id = tenant_id
self._app_id = app_id
def _delivery_method_to_model(
self,
@@ -340,6 +342,9 @@ class HumanInputFormRepositoryImpl:
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
app_id = params.app_id or self._app_id
if not app_id:
raise ValueError("app_id is required to create a human input form")
with session_factory.create_session() as session, session.begin():
# Generate unique form ID
@@ -359,7 +364,7 @@ class HumanInputFormRepositoryImpl:
form_model = HumanInputForm(
id=form_id,
tenant_id=self._tenant_id,
app_id=params.app_id,
app_id=app_id,
workflow_run_id=params.workflow_execution_id,
form_kind=params.form_kind,
node_id=params.node_id,

View File

@@ -22,6 +22,9 @@ class ASRTool(BuiltinTool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
if not self.runtime:
raise ValueError("Runtime is required")
runtime = self.runtime
file = tool_parameters.get("audio_file")
if file.type != FileType.AUDIO: # type: ignore
yield self.create_text_message("not a valid audio file")
@@ -29,20 +32,19 @@ class ASRTool(BuiltinTool):
audio_binary = io.BytesIO(download(file)) # type: ignore
audio_binary.name = "temp.mp3"
provider, model = tool_parameters.get("model").split("#") # type: ignore
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
tenant_id=runtime.tenant_id,
provider=provider,
model_type=ModelType.SPEECH2TEXT,
model=model,
)
text = model_instance.invoke_speech2text(
file=audio_binary,
user=user_id,
)
text = model_instance.invoke_speech2text(file=audio_binary)
yield self.create_text_message(text)
def get_available_models(self) -> list[tuple[str, str]]:
if not self.runtime:
raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(
tenant_id=self.runtime.tenant_id, model_type="speech2text"

View File

@@ -20,13 +20,14 @@ class TTSTool(BuiltinTool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
provider, model = tool_parameters.get("model").split("#") # type: ignore
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
runtime = self.runtime
provider, model = tool_parameters.get("model").split("#") # type: ignore
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id or "",
tenant_id=runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
@@ -39,12 +40,7 @@ class TTSTool(BuiltinTool):
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
tenant_id=self.runtime.tenant_id,
voice=voice,
)
tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type]
buffer = io.BytesIO()
for chunk in tts:
buffer.write(chunk)

View File

@@ -65,7 +65,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,

View File

@@ -37,7 +37,7 @@ class ModelInvocationUtils:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -65,7 +65,7 @@ class ModelInvocationUtils:
"""
# get model instance
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
@@ -92,7 +92,7 @@ class ModelInvocationUtils:
"""
# get model manager
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
@@ -136,7 +136,6 @@ class ModelInvocationUtils:
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:

View File

@@ -9,8 +9,8 @@ from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm.model_access import build_dify_model_access, fetch_model_config
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
@@ -20,8 +20,17 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
DifyPreparedLLM,
DifyPromptMessageSerializer,
DifyRetrieverAttachmentLoader,
DifyToolFileManager,
DifyToolNodeRuntime,
build_dify_llm_file_saver,
)
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
from core.workflow.nodes.agent.plugin_strategy_adapter import (
PluginAgentStrategyPresentationProvider,
@@ -30,11 +39,9 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
from dify_graph.graph.graph import NodeFactory
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.base.node import Node
@@ -44,13 +51,10 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
from dify_graph.nodes.document_extractor import UnstructuredApiConfig
from dify_graph.nodes.http_request import build_http_request_config
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import TemplateRenderer
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
from dify_graph.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from dify_graph.template_rendering import CodeExecutorJinja2TemplateRenderer
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
@@ -264,11 +268,22 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(self._dify_context)
self._file_reference_factory = DifyFileReferenceFactory(self._dify_context)
self._prompt_message_serializer = DifyPromptMessageSerializer()
self._retriever_attachment_loader = DifyRetrieverAttachmentLoader(
file_reference_factory=self._file_reference_factory,
)
self._llm_file_saver = build_dify_llm_file_saver(
run_context=self._dify_context,
http_client=self._http_request_http_client,
)
self._human_input_runtime = DifyHumanInputNodeRuntime(self._dify_context)
self._tool_runtime = DifyToolNodeRuntime(self._dify_context)
self._http_request_file_manager = file_manager
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
@@ -284,7 +299,7 @@ class DifyNodeFactory(NodeFactory):
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context)
self._agent_strategy_resolver = PluginAgentStrategyResolver()
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
self._agent_runtime_support = AgentRuntimeSupport()
@@ -321,22 +336,32 @@ class DifyNodeFactory(NodeFactory):
"code_limits": self._code_limits,
},
BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: {
"template_renderer": self._template_renderer,
"jinja2_template_renderer": self._jinja2_template_renderer,
"max_output_length": self._template_transform_max_output_length,
},
BuiltinNodeTypes.HTTP_REQUEST: lambda: {
"http_request_config": self._http_request_config,
"http_client": self._http_request_http_client,
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
"tool_file_manager_factory": self._bound_tool_file_manager_factory,
"file_manager": self._http_request_file_manager,
"file_reference_factory": self._file_reference_factory,
},
BuiltinNodeTypes.HUMAN_INPUT: lambda: {
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
"form_repository": HumanInputFormRepositoryImpl(
tenant_id=self._dify_context.tenant_id,
app_id=self._dify_context.app_id,
),
"runtime": self._human_input_runtime,
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
include_prompt_message_serializer=True,
include_retriever_attachment_loader=True,
include_jinja2_template_renderer=True,
),
BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
@@ -345,15 +370,26 @@ class DifyNodeFactory(NodeFactory):
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
include_prompt_message_serializer=True,
include_retriever_attachment_loader=False,
include_jinja2_template_renderer=False,
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
include_prompt_message_serializer=True,
include_retriever_attachment_loader=False,
include_jinja2_template_renderer=False,
),
BuiltinNodeTypes.TOOL: lambda: {
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
"tool_file_manager_factory": self._bound_tool_file_manager_factory(),
"runtime": self._tool_runtime,
},
BuiltinNodeTypes.AGENT: lambda: {
"strategy_resolver": self._agent_strategy_resolver,
@@ -387,7 +423,12 @@ class DifyNodeFactory(NodeFactory):
*,
node_class: type[Node],
node_data: BaseNodeData,
wrap_model_instance: bool,
include_http_client: bool,
include_llm_file_saver: bool,
include_prompt_message_serializer: bool,
include_retriever_attachment_loader: bool,
include_jinja2_template_renderer: bool,
) -> dict[str, object]:
validated_node_data = cast(
LLMCompatibleNodeData,
@@ -397,49 +438,33 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,
"model_factory": self._llm_model_factory,
"model_instance": model_instance,
"model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance,
"memory": self._build_memory_for_llm_node(
node_data=validated_node_data,
model_instance=model_instance,
),
}
if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER:
node_init_kwargs["template_renderer"] = self._llm_template_renderer
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
if include_llm_file_saver:
node_init_kwargs["llm_file_saver"] = self._llm_file_saver
if include_prompt_message_serializer:
node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer
if include_retriever_attachment_loader:
node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader
if include_jinja2_template_renderer:
node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer
return node_init_kwargs
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
node_data_model = node_data.model
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
model_instance, _ = fetch_model_config(
node_data_model=node_data_model,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.stop = tuple(stop)
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance

View File

@@ -0,0 +1,532 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
from core.plugin.impl.plugin import PluginInstaller
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType
from core.tools.errors import ToolInvokeError
from core.tools.signature import sign_upload_file
from core.tools.tool_engine import ToolEngine
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.file import FileTransferMethod, FileType
from dify_graph.model_runtime.entities import LLMMode
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, apply_debug_email_recipient
from dify_graph.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
PromptMessageSerializerProtocol,
RetrieverAttachmentLoaderProtocol,
)
from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.runtime import (
HumanInputNodeRuntimeProtocol,
ToolNodeRuntimeProtocol,
)
from dify_graph.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError
from dify_graph.nodes.tool_runtime_entities import (
ToolRuntimeHandle,
ToolRuntimeMessage,
ToolRuntimeParameter,
)
from extensions.ext_database import db
from factories import file_factory
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
from dify_graph.file import File
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.tool.entities import ToolNodeData
def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext:
if isinstance(run_context, DifyRunContext):
return run_context
raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY)
if raw_ctx is None:
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
if isinstance(raw_ctx, DifyRunContext):
return raw_ctx
return DifyRunContext.model_validate(raw_ctx)
class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
def build_from_mapping(self, *, mapping: Mapping[str, Any]):
return file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self._run_context.tenant_id,
)
class DifyPreparedLLM(PreparedLLMProtocol):
"""Workflow-layer adapter that hides the full `ModelInstance` API from `dify_graph` nodes."""
def __init__(self, model_instance: ModelInstance) -> None:
self._model_instance = model_instance
@property
def provider(self) -> str:
return self._model_instance.provider
@property
def model_name(self) -> str:
return self._model_instance.model_name
@property
def parameters(self) -> Mapping[str, Any]:
return self._model_instance.parameters
@property
def stop(self) -> Sequence[str] | None:
return self._model_instance.stop
def get_model_schema(self) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema(
self._model_instance.model_name,
self._model_instance.credentials,
)
if model_schema is None:
raise ValueError(f"Model schema not found for {self._model_instance.model_name}")
return model_schema
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> LLMResult | Generator[LLMResultChunk, None, None]:
return self._model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=dict(model_parameters),
tools=list(tools or []),
stop=list(stop or []),
stream=stream,
)
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: bool,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
return invoke_llm_with_structured_output(
provider=self.provider,
model_schema=self.get_model_schema(),
model_instance=self._model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
stop=list(stop or []),
stream=stream,
)
def is_structured_output_parse_error(self, error: Exception) -> bool:
return isinstance(error, OutputParserError)
class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
def serialize(
self,
*,
model_mode: LLMMode,
prompt_messages: Sequence[PromptMessage],
) -> Any:
return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_mode,
prompt_messages=prompt_messages,
)
class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
"""Resolve retriever attachments through Dify persistence and return graph file references."""
def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None:
self._file_reference_factory = file_reference_factory
def load(self, *, segment_id: str) -> Sequence[File]:
with Session(db.engine, expire_on_commit=False) as session:
attachments_with_bindings = session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(SegmentAttachmentBinding.segment_id == segment_id)
).all()
return [
self._file_reference_factory.build_from_mapping(
mapping={
"id": upload_file.id,
"filename": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"type": FileType.IMAGE,
"transfer_method": FileTransferMethod.LOCAL_FILE,
"remote_url": upload_file.source_url,
"related_id": upload_file.id,
"upload_file_id": upload_file.id,
"size": upload_file.size,
"storage_key": upload_file.key,
"url": sign_upload_file(upload_file.id, upload_file.extension),
}
)
for _, upload_file in attachments_with_bindings
]
class DifyToolFileManager(ToolFileManagerProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
self._manager = ToolFileManager()
def create_file_by_raw(
self,
*,
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
filename: str | None = None,
) -> Any:
return self._manager.create_file_by_raw(
user_id=self._run_context.user_id,
tenant_id=self._run_context.tenant_id,
conversation_id=conversation_id,
file_binary=file_binary,
mimetype=mimetype,
filename=filename,
)
def get_file_generator_by_tool_file_id(self, tool_file_id: str):
return self._manager.get_file_generator_by_tool_file_id(tool_file_id)
@dataclass(frozen=True, slots=True)
class _WorkflowToolRuntimeSpec:
provider_type: CoreToolProviderType
provider_id: str
tool_name: str
tool_configurations: dict[str, Any]
credential_id: str | None = None
class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
self._file_reference_factory = DifyFileReferenceFactory(self._run_context)
@property
def file_reference_factory(self) -> FileReferenceFactoryProtocol:
return self._file_reference_factory
def build_file_reference(self, *, mapping: Mapping[str, Any]):
return self._file_reference_factory.build_from_mapping(mapping=mapping)
def get_runtime(
self,
*,
node_id: str,
node_data: ToolNodeData,
variable_pool,
) -> ToolRuntimeHandle:
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(
self._run_context.tenant_id,
self._run_context.app_id,
node_id,
self._build_tool_runtime_spec(node_data),
self._run_context.invoke_from,
variable_pool,
)
except ToolNodeError:
raise
except Exception as exc:
raise ToolRuntimeResolutionError(str(exc)) from exc
return ToolRuntimeHandle(raw=tool_runtime)
def get_runtime_parameters(
self,
*,
tool_runtime: ToolRuntimeHandle,
) -> Sequence[ToolRuntimeParameter]:
tool = self._tool_from_handle(tool_runtime)
return [
ToolRuntimeParameter(name=parameter.name, required=parameter.required)
for parameter in (tool.get_merged_runtime_parameters() or [])
]
def invoke(
self,
*,
tool_runtime: ToolRuntimeHandle,
tool_parameters: Mapping[str, Any],
workflow_call_depth: int,
conversation_id: str | None,
provider_name: str,
) -> Generator[ToolRuntimeMessage, None, None]:
tool = self._tool_from_handle(tool_runtime)
callback = DifyWorkflowCallbackHandler()
try:
messages = ToolEngine.generic_invoke(
tool=tool,
tool_parameters=dict(tool_parameters),
user_id=self._run_context.user_id,
workflow_tool_callback=callback,
workflow_call_depth=workflow_call_depth,
app_id=self._run_context.app_id,
conversation_id=conversation_id,
)
except Exception as exc:
raise self._map_invocation_exception(exc, provider_name=provider_name) from exc
transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self._run_context.user_id,
tenant_id=self._run_context.tenant_id,
conversation_id=None,
)
return self._adapt_messages(transformed_messages, provider_name=provider_name)
def get_usage(
self,
*,
tool_runtime: ToolRuntimeHandle,
) -> LLMUsage:
latest = getattr(tool_runtime.raw, "latest_usage", None)
if isinstance(latest, LLMUsage):
return latest
if isinstance(latest, dict):
return LLMUsage.model_validate(latest)
return LLMUsage.empty_usage()
def resolve_provider_icons(
self,
*,
provider_name: str,
default_icon: str | None = None,
) -> tuple[str | None, str | None]:
icon = default_icon
icon_dark = None
manager = PluginInstaller()
plugins = manager.list_plugins(self._run_context.tenant_id)
try:
current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name)
icon = current_plugin.declaration.icon
except StopIteration:
pass
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
self._run_context.user_id,
self._run_context.tenant_id,
)
if provider.name == provider_name
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
return icon, icon_dark
@staticmethod
def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool:
return cast("Tool", tool_runtime.raw)
@staticmethod
def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
return _WorkflowToolRuntimeSpec(
provider_type=CoreToolProviderType(node_data.provider_type.value),
provider_id=node_data.provider_id,
tool_name=node_data.tool_name,
tool_configurations=dict(node_data.tool_configurations),
credential_id=node_data.credential_id,
)
def _adapt_messages(
self,
messages: Generator[CoreToolInvokeMessage, None, None],
*,
provider_name: str,
) -> Generator[ToolRuntimeMessage, None, None]:
try:
for message in messages:
yield self._convert_message(message)
except Exception as exc:
raise self._map_invocation_exception(exc, provider_name=provider_name) from exc
def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage:
graph_message_type = ToolRuntimeMessage.MessageType(message.type.value)
graph_message = self._convert_message_payload(message.message)
graph_meta = message.meta.copy() if message.meta is not None else None
return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta)
def _convert_message_payload(
self,
message: CoreToolInvokeMessage.TextMessage
| CoreToolInvokeMessage.JsonMessage
| CoreToolInvokeMessage.BlobChunkMessage
| CoreToolInvokeMessage.BlobMessage
| CoreToolInvokeMessage.LogMessage
| CoreToolInvokeMessage.FileMessage
| CoreToolInvokeMessage.VariableMessage
| CoreToolInvokeMessage.RetrieverResourceMessage
| None,
) -> (
ToolRuntimeMessage.TextMessage
| ToolRuntimeMessage.JsonMessage
| ToolRuntimeMessage.BlobChunkMessage
| ToolRuntimeMessage.BlobMessage
| ToolRuntimeMessage.LogMessage
| ToolRuntimeMessage.FileMessage
| ToolRuntimeMessage.VariableMessage
| ToolRuntimeMessage.RetrieverResourceMessage
| None
):
if message is None:
return None
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
if isinstance(message, CoreToolInvokeMessage.TextMessage):
return ToolRuntimeMessage.TextMessage(text=message.text)
if isinstance(message, CoreToolInvokeMessage.JsonMessage):
return ToolRuntimeMessage.JsonMessage(
json_object=message.json_object,
suppress_output=message.suppress_output,
)
if isinstance(message, CoreToolInvokeMessage.BlobMessage):
return ToolRuntimeMessage.BlobMessage(blob=message.blob)
if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage):
return ToolRuntimeMessage.BlobChunkMessage(
id=message.id,
sequence=message.sequence,
total_length=message.total_length,
blob=message.blob,
end=message.end,
)
if isinstance(message, CoreToolInvokeMessage.FileMessage):
return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker)
if isinstance(message, CoreToolInvokeMessage.VariableMessage):
return ToolRuntimeMessage.VariableMessage(
variable_name=message.variable_name,
variable_value=message.variable_value,
stream=message.stream,
)
if isinstance(message, CoreToolInvokeMessage.LogMessage):
return ToolRuntimeMessage.LogMessage(
id=message.id,
label=message.label,
parent_id=message.parent_id,
error=message.error,
status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value),
data=dict(message.data),
metadata=dict(message.metadata),
)
if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage):
retriever_resources = [
resource.model_dump() if hasattr(resource, "model_dump") else dict(resource)
for resource in message.retriever_resources
]
return ToolRuntimeMessage.RetrieverResourceMessage(
retriever_resources=retriever_resources,
context=message.context,
)
raise TypeError(f"unsupported tool message payload: {type(message).__name__}")
@staticmethod
def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError:
if isinstance(exc, ToolNodeError):
return exc
if isinstance(exc, PluginInvokeError):
return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name))
if isinstance(exc, PluginDaemonClientSideError):
return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}")
if isinstance(exc, ToolInvokeError):
return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}")
return ToolRuntimeInvocationError(str(exc))
class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
def invoke_source(self) -> str:
invoke_from = self._run_context.invoke_from
if isinstance(invoke_from, str):
return invoke_from
return str(getattr(invoke_from, "value", invoke_from))
def apply_delivery_runtime(
self,
*,
methods: Sequence[DeliveryChannelConfig],
) -> Sequence[DeliveryChannelConfig]:
return [
apply_debug_email_recipient(
method,
enabled=self.invoke_source() == "debugger",
user_id=self._run_context.user_id,
)
for method in methods
]
def console_actor_id(self) -> str | None:
return self._run_context.user_id
def build_dify_llm_file_saver(
*,
run_context: Mapping[str, Any] | DifyRunContext,
http_client: HttpClientProtocol,
) -> LLMFileSaver:
from dify_graph.nodes.llm.file_saver import FileSaverImpl
return FileSaverImpl(
tool_file_manager=DifyToolFileManager(run_context),
file_reference_factory=DifyFileReferenceFactory(run_context),
http_client=http_client,
)

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
@@ -59,7 +60,7 @@ class AgentNode(Node[AgentNodeData]):
return "1"
def populate_start_event(self, event) -> None:
dify_ctx = self.require_dify_context()
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
event.extras["agent_strategy"] = {
"name": self.node_data.agent_strategy_name,
"icon": self._presentation_provider.get_icon(
@@ -71,7 +72,7 @@ class AgentNode(Node[AgentNodeData]):
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
try:
strategy = self._strategy_resolver.resolve(
@@ -97,6 +98,7 @@ class AgentNode(Node[AgentNodeData]):
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
)
@@ -106,6 +108,7 @@ class AgentNode(Node[AgentNodeData]):
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
for_log=True,

View File

@@ -14,7 +14,7 @@ from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from dify_graph.enums import SystemVariableKey
@@ -38,6 +38,7 @@ class AgentRuntimeSupport:
node_data: AgentNodeData,
strategy: ResolvedAgentStrategy,
tenant_id: str,
user_id: str,
app_id: str,
invoke_from: Any,
for_log: bool = False,
@@ -174,7 +175,11 @@ class AgentRuntimeSupport:
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
model_instance, model_schema = self.fetch_model(
tenant_id=tenant_id,
user_id=user_id,
value=value,
)
history_prompt_messages = []
if node_data.memory:
memory = self.fetch_memory(
@@ -232,8 +237,14 @@ class AgentRuntimeSupport:
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
def fetch_model(
self,
*,
tenant_id: str,
user_id: str,
value: dict[str, Any],
) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
@@ -246,7 +257,7 @@ class AgentRuntimeSupport:
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),

View File

@@ -1,6 +1,7 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
@@ -50,8 +51,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
"""
Run the datasource node
"""
dify_ctx = self.require_dify_context()
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])

View File

@@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from dify_graph.entities import GraphInitParams
@@ -160,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[Source], LLMUsage]:
dify_ctx = self.require_dify_context()
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")

View File

@@ -9,9 +9,9 @@ from dify_graph.enums import NodeExecutionType
from dify_graph.file import FileTransferMethod
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.protocols import FileReferenceFactoryProtocol
from dify_graph.variables.types import SegmentType
from dify_graph.variables.variables import FileVariable
from factories import file_factory
from factories.variable_factory import build_segment_with_type
from .entities import ContentType, WebhookData
@@ -23,6 +23,13 @@ class TriggerWebhookNode(Node[WebhookData]):
node_type = TRIGGER_WEBHOOK_NODE_TYPE
execution_type = NodeExecutionType.ROOT
_file_reference_factory: FileReferenceFactoryProtocol
def post_init(self) -> None:
from core.workflow.node_runtime import DifyFileReferenceFactory
self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@@ -70,7 +77,6 @@ class TriggerWebhookNode(Node[WebhookData]):
)
def generate_file_var(self, param_name: str, file: dict):
dify_ctx = self.require_dify_context()
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
@@ -84,10 +90,7 @@ class TriggerWebhookNode(Node[WebhookData]):
file["datasource_file_id"] = related_id
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=dify_ctx.tenant_id,
)
file_obj = self._file_reference_factory.build_from_mapping(mapping=file)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
except ValueError: