feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -33,20 +33,6 @@ ignore_imports =
# TODO(QuantumGhost): fix the import violation later
dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
dify_graph
forbidden_modules =
extensions.ext_database
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
type = forbidden
@@ -89,45 +75,7 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
dify_graph.nodes.llm.llm_utils -> core.model_manager
dify_graph.nodes.llm.protocols -> core.model_manager
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
dify_graph.nodes.llm.node -> core.tools.signature
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.llm.file_saver -> core.tools.signature
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.tool.tool_node -> services
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.__base.large_language_model -> configs
dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
dify_graph.model_runtime.model_providers.model_provider_factory -> configs
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids
[importlinter:contract:rsc]
name = RSC

View File

@@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING:
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
@@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)

View File

@@ -25,7 +25,7 @@ from controllers.console.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
@@ -331,7 +331,7 @@ class DatasetListApi(Resource):
)
# check embedding setting
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -445,7 +445,7 @@ class DatasetApi(Resource):
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

View File

@@ -452,7 +452,7 @@ class DatasetInitApi(Resource):
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=knowledge_config.embedding_model_provider,

View File

@@ -282,7 +282,7 @@ class DatasetDocumentSegmentApi(Resource):
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -335,7 +335,7 @@ class DatasetDocumentSegmentAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -386,7 +386,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -571,7 +571,7 @@ class ChildChunkAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,

View File

@@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource):
)
if args.config_from == "predefined-model":
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
available_credentials = model_provider_service.get_provider_model_available_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=normalized_model_type,
model=args.model,
)
return jsonable_encoder(

View File

@@ -14,7 +14,7 @@ from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
)
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from dify_graph.model_runtime.entities.model_entities import ModelType
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
@@ -139,10 +139,10 @@ class DatasetListApi(DatasetApiResource):
query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
)
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -253,10 +253,10 @@ class DatasetApi(DatasetApiResource):
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

View File

@@ -105,7 +105,7 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -159,7 +159,7 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -265,7 +265,7 @@ class DatasetSegmentApi(DatasetApiResource):
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -360,7 +360,7 @@ class ChildChunkApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,

View File

@@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tools=[],
stop=app_generate_entity.model_conf.stop,
stream=True,
user=self.user_id,
callbacks=[],
)

View File

@@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)

View File

@@ -4,7 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -21,7 +21,7 @@ class ModelConfigConverter:
"""
model_config = app_config.model
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
)

View File

@@ -2,9 +2,8 @@ from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID
@@ -55,7 +54,7 @@ class ModelConfigManager:
raise ValueError("model must be of object type")
# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"]:
@@ -71,7 +70,7 @@ class ModelConfigManager:
if "name" not in config["model"]:
raise ValueError("model.name is required")
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"], model_type=ModelType.LLM
)

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.file import File, FileUploadConfig
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
@@ -15,6 +14,9 @@ if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
DIFY_RUN_CONTEXT_KEY = "_dify"
class UserFrom(StrEnum):
ACCOUNT = "account"
END_USER = "end-user"

View File

@@ -2,23 +2,34 @@ from __future__ import annotations
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
from core.app.llm.protocols import CredentialsProvider, ModelFactory
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.nodes.llm.entities import ModelConfig
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
class DifyCredentialsProvider:
tenant_id: str
provider_manager: ProviderManager
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
self.tenant_id = tenant_id
self.provider_manager = provider_manager or ProviderManager()
def __init__(
self,
*,
run_context: DifyRunContext,
provider_manager: ProviderManager | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
if provider_manager is None:
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
@@ -42,9 +53,21 @@ class DifyModelFactory:
tenant_id: str
model_manager: ModelManager
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
self.tenant_id = tenant_id
self.model_manager = model_manager or ModelManager()
def __init__(
self,
*,
run_context: DifyRunContext,
model_manager: ModelManager | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
if model_manager is None:
model_manager = ModelManager(
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
)
self.model_manager = model_manager
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
return self.model_manager.get_model_instance(
@@ -55,11 +78,35 @@ class DifyModelFactory:
)
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
return (
DifyCredentialsProvider(tenant_id=tenant_id),
DifyModelFactory(tenant_id=tenant_id),
def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, ModelFactory]:
"""Create LLM access adapters that share the same tenant-bound manager graph."""
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
DifyModelFactory(run_context=run_context, model_manager=model_manager),
)
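A hedged usage sketch (not part of the diff) of how a caller might consume the adapter pair returned by build_dify_model_access; the import path, the DifyRunContext constructor fields, and the "openai"/"gpt-4o" names are illustrative assumptions.
from core.app.entities.app_invoke_entities import DifyRunContext  # referenced in the diff above

run_context = DifyRunContext(tenant_id="tenant-1", user_id="user-1")  # assumed constructor fields
credentials_provider, model_factory = build_dify_model_access(run_context)
credentials = credentials_provider.fetch(provider_name="openai", model_name="gpt-4o")
model_instance = model_factory.init_model_instance(provider_name="openai", model_name="gpt-4o")
Both adapters are built from the same tenant-bound provider manager, so credential lookup and instance creation stay consistent within one run context.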
def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Split node-level completion params into provider parameters and stop sequences.
Workflow LLM-compatible nodes still consume runtime invocation settings from
``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the
``ModelInstance`` view and the returned config entity aligned here so callers
do not need to duplicate normalization logic.
"""
normalized_parameters = dict(completion_params)
stop = normalized_parameters.pop("stop", [])
if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop):
stop = []
return normalized_parameters, stop
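A hedged behavioral sketch (not part of the diff) of the helper above; the parameter values are illustrative.
params, stop = _normalize_completion_params({"temperature": 0.7, "stop": ["\n\nHuman:"]})
# params == {"temperature": 0.7}, stop == ["\n\nHuman:"]
params, stop = _normalize_completion_params({"temperature": 0.7, "stop": "not-a-list"})
# a non-list (or non-string-item) "stop" value is discarded: stop == []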
def fetch_model_config(
@@ -80,22 +127,18 @@ def fetch_model_config(
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
if model_schema is None:
raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.")
parameters, stop = _normalize_completion_params(node_data_model.completion_params)
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.parameters = parameters
model_instance.stop = tuple(stop)
return model_instance, ModelConfigWithCredentialsEntity(
@@ -103,8 +146,8 @@ def fetch_model_config(
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=completion_params,
parameters=parameters,
stop=stop,
provider_model_bundle=provider_model_bundle,
)

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from typing import Any, Protocol
from core.model_manager import ModelInstance
class CredentialsProvider(Protocol):
"""Workflow-layer port for loading runtime credentials for a provider/model pair."""
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
"""Return credentials for the target provider/model or raise a domain error."""
...
class ModelFactory(Protocol):
"""Workflow-layer port for creating mutable ModelInstance objects."""
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for workflow-side hydration."""
...
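A hedged sketch (not part of the diff) of a minimal in-memory double that structurally satisfies CredentialsProvider for tests; the class name is an assumption, and a matching ModelFactory double is omitted because constructing a ModelInstance requires a provider bundle.
from typing import Any


class StaticCredentialsProvider:
    """Test double: returns the same credentials for every provider/model pair."""

    def __init__(self, credentials: dict[str, Any]) -> None:
        self._credentials = credentials

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        # Protocols are structural, so no inheritance from CredentialsProvider is needed.
        return dict(self._credentials)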

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
@@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer):
return
try:
dify_ctx = node.require_dify_context()
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
deduct_llm_quota(
tenant_id=dify_ctx.tenant_id,
model_instance=model_instance,

View File

@@ -25,12 +25,10 @@ class AudioTrunk:
self.status = status
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)
return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice)
def _process_future(
@@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher:
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts")
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
@@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher:
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
_invoice_tts, self.msg_text, self.model_instance, self.voice
)
future_queue.put(futures_result)
break
@@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher:
if len(sentence_arr) >= min(self.max_sentence, 7):
self.max_sentence += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice)
future_queue.put(futures_result)
if isinstance(text_tmp, str):
self.msg_text = text_tmp

View File

@@ -1,10 +1,5 @@
from enum import StrEnum, auto
"""Compatibility wrapper for the runtime embedding input enum."""
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType
class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""
DOCUMENT = auto()
QUERY = auto()
__all__ = ["EmbeddingInputType"]

View File

@@ -19,6 +19,7 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from dify_graph.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@@ -27,7 +28,6 @@ from dify_graph.model_runtime.entities.provider_entities import (
ProviderEntity,
)
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@@ -343,7 +343,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
@@ -902,7 +902,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1388,7 +1388,7 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
@@ -1397,7 +1397,7 @@ class ProviderConfiguration(BaseModel):
"""
Get model schema
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1499,7 +1499,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types: list[ModelType] = []

View File

@@ -4,10 +4,10 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType
@@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
text_chunk = secrets.choice(text_chunks)
try:
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(

View File

@@ -50,7 +50,10 @@ logger = logging.getLogger(__name__)
class IndexingRunner:
def __init__(self):
self.storage = storage
self.model_manager = ModelManager()
@staticmethod
def _get_model_manager(tenant_id: str) -> ModelManager:
return ModelManager.for_tenant(tenant_id=tenant_id)
def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
"""Handle indexing errors by updating document status."""
@@ -291,20 +294,20 @@ class IndexingRunner:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == "high_quality":
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
@@ -574,7 +577,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
@@ -766,14 +769,14 @@ class IndexingRunner:
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)

View File

@@ -62,7 +62,7 @@ class LLMGenerator:
prompt += query + "\n"
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -120,7 +120,7 @@ class LLMGenerator:
prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions})
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -172,7 +172,7 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt_generate)]
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
@@ -219,7 +219,7 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)]
# get model instance
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -306,7 +306,7 @@ class LLMGenerator:
remove_template_variables=False,
)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -337,7 +337,7 @@ class LLMGenerator:
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt = GENERATOR_QA_PROMPT.format(language=document_language)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -362,7 +362,7 @@ class LLMGenerator:
@classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -536,7 +536,7 @@ class LLMGenerator:
injected_instruction = injected_instruction.replace(CURRENT, current or "null")
if ERROR_MESSAGE in injected_instruction:
injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
model_instance = ModelManager().get_model_instance(
model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,

View File

@@ -55,7 +55,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True],
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
@@ -70,7 +69,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False],
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
@@ -85,7 +83,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
@@ -99,7 +96,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
"""
@@ -113,7 +109,6 @@ def invoke_llm_with_structured_output(
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -143,7 +138,6 @@ def invoke_llm_with_structured_output(
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)

View File

@@ -7,6 +7,7 @@ from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import LLMResult
@@ -30,7 +31,7 @@ logger = logging.getLogger(__name__)
class ModelInstance:
"""
Model instance class
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
@@ -110,7 +111,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator: ...
@@ -122,7 +122,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResult: ...
@@ -134,7 +133,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]: ...
@@ -145,7 +143,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]:
"""
@@ -156,7 +153,6 @@ class ModelInstance:
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -173,7 +169,6 @@ class ModelInstance:
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
),
)
@@ -202,13 +197,12 @@ class ModelInstance:
)
def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> EmbeddingResult:
"""
Invoke large language model
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -221,7 +215,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
),
)
@@ -229,14 +222,12 @@ class ModelInstance:
def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke large language model
:param multimodel_documents: multimodel documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -249,7 +240,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)
@@ -279,7 +269,6 @@ class ModelInstance:
docs: list[str],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -288,7 +277,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -303,7 +291,6 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
@@ -313,7 +300,6 @@ class ModelInstance:
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -322,7 +308,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -337,16 +322,14 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
def invoke_moderation(self, text: str) -> bool:
"""
Invoke moderation model
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
@@ -358,16 +341,14 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
text=text,
user=user,
),
)
def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str:
def invoke_speech2text(self, file: IO[bytes]) -> str:
"""
Invoke large language model
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
@@ -379,18 +360,15 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
file=file,
user=user,
),
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]:
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
"""
Invoke large language tts model
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param voice: model timbre
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
@@ -402,8 +380,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
),
)
@@ -477,10 +453,20 @@ class ModelInstance:
class ModelManager:
def __init__(self):
self._provider_manager = ProviderManager()
def __init__(self, provider_manager: ProviderManager):
self._provider_manager = provider_manager
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id))
def get_model_instance(
self,
tenant_id: str,
provider: str,
model_type: ModelType,
model: str,
) -> ModelInstance:
"""
Get model instance
:param tenant_id: tenant id
@@ -496,7 +482,8 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
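A hedged usage sketch (not part of the diff) of the new tenant-bound construction path; the provider and model names are assumptions.
# Before: manager = ModelManager()  # implicit, process-global provider manager
# After: bind the manager to a tenant (and optionally a user) up front.
manager = ModelManager.for_tenant(tenant_id="tenant-1", user_id="user-1")
instance = manager.get_model_instance(
    tenant_id="tenant-1",
    provider="openai",       # assumed provider name
    model_type=ModelType.LLM,
    model="gpt-4o",          # assumed model name
)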
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@@ -50,7 +50,7 @@ class OpenAIModeration(Moderation):
def _is_violated(self, inputs: dict):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
)

View File

@@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Tenant
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
@staticmethod
def _get_bound_model_instance(
*,
tenant_id: str,
user_id: str | None,
provider: str,
model_type: ModelType,
model: str,
):
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=provider,
model_type=model_type,
model=model,
)
@classmethod
def invoke_llm(
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
@@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
)
if isinstance(response, Generator):
@@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm with structured output
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
model_parameters=payload.completion_params,
)
@@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke text embedding
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_text_embedding(
texts=payload.texts,
user=user_id,
)
response = model_instance.invoke_text_embedding(texts=payload.texts)
return response
@@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke rerank
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
docs=payload.docs,
score_threshold=payload.score_threshold,
top_n=payload.top_n,
user=user_id,
)
return response
@@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke tts
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_tts(
content_text=payload.content_text,
tenant_id=tenant.id,
voice=payload.voice,
user=user_id,
)
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
def handle() -> Generator[dict, None, None]:
for chunk in response:
@@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke speech2text
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
temp.flush()
temp.seek(0)
response = model_instance.invoke_speech2text(
file=temp,
user=user_id,
)
response = model_instance.invoke_speech2text(file=temp)
return {
"result": response,
@@ -252,18 +261,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke moderation
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_moderation(
text=payload.text,
user=user_id,
)
response = model_instance.invoke_moderation(text=payload.text)
return {
"result": response,

View File

@@ -1,6 +1,6 @@
import binascii
from collections.abc import Generator, Sequence
from typing import IO
from typing import IO, Any
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
@@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
class PluginModelClient(BasePluginClient):
@staticmethod
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {"data": data}
if user_id is not None:
payload["user_id"] = user_id
return payload
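A hedged sketch (not part of the diff) of the payload shapes the helper above produces; the provider value is illustrative.
PluginModelClient._dispatch_payload(user_id=None, data={"provider": "openai"})
# -> {"data": {"provider": "openai"}}
PluginModelClient._dispatch_payload(user_id="user-1", data={"provider": "openai"})
# -> {"data": {"provider": "openai"}, "user_id": "user-1"}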
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
"""
Fetch model providers for the given tenant.
@@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient):
def get_model_schema(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/schema",
PluginModelSchemaEntity,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient):
return None
def validate_provider_credentials(
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
) -> bool:
"""
validate the credentials of the provider
@@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient):
def validate_model_credentials(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient):
def invoke_llm(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "llm",
"model": model,
@@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient):
"stop": stop,
"stream": stream,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient):
def get_llm_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
@@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient):
"prompt_messages": prompt_messages,
"tools": tools,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient):
def invoke_text_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient):
"texts": texts,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient):
"documents": documents,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient):
def get_text_embedding_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"texts": texts,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient):
def invoke_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient):
def invoke_tts(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
@@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient):
"content_text": content_text,
"voice": voice,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient):
def get_tts_model_voices(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
"credentials": credentials,
"language": language,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient):
def invoke_speech_to_text(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "speech2text",
"model": model,
"credentials": credentials,
"file": binascii.hexlify(file.read()).decode(),
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient):
def invoke_moderation(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "moderation",
"model": model,
"credentials": credentials,
"text": text,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,

View File

@@ -0,0 +1,490 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union
from pydantic import ValidationError
from redis import RedisError
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.runtime import ModelRuntime
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and a default user."""
tenant_id: str
user_id: str | None
client: PluginModelClient
_provider_entities: tuple[ProviderEntity, ...] | None
_provider_entities_lock: Lock
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
if client is None:
raise ValueError("client is required.")
self.tenant_id = tenant_id
self.user_id = user_id
self.client = client
self._provider_entities = None
self._provider_entities_lock = Lock()
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
if self._provider_entities is not None:
return self._provider_entities
with self._provider_entities_lock:
if self._provider_entities is None:
self._provider_entities = tuple(
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
)
return self._provider_entities
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_Hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_US
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}
extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
credentials=credentials,
)
def validate_model_credentials(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_model_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
def get_model_schema(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> AIModelEntity | None:
cache_key = self._get_schema_cache_key(
provider=provider,
model_type=model_type,
model=model,
credentials=credentials,
)
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
plugin_id, provider_name = self._split_provider(provider)
schema = self.client.get_model_schema(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stream=stream,
)
def get_llm_num_tokens(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
) -> int:
if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
return 0
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)
def invoke_text_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
def invoke_multimodal_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
documents: list[dict[str, Any]],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
documents=documents,
input_type=input_type,
)
def get_text_embedding_num_tokens(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
)
def invoke_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_multimodal_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_tts(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Iterable[bytes]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_tts(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def get_tts_model_voices(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
language: str | None,
) -> Any:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
language=language,
)
def invoke_speech_to_text(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
file=file,
)
def invoke_moderation(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
text: str,
) -> bool:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_moderation(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
text=text,
)
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.
Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""
if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""
return provider.provider
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = self._get_provider_short_name_alias(provider)
return declaration
def _get_provider_schema(self, provider: str) -> ProviderEntity:
providers = self.fetch_model_providers()
provider_entity = next((item for item in providers if item.provider == provider), None)
if provider_entity is None:
provider_entity = next((item for item in providers if provider == item.provider_name), None)
if provider_entity is None:
raise ValueError(f"Invalid provider: {provider}")
return provider_entity
def _get_schema_cache_key(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> str:
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
return cache_key + ":".join(
[hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
)
def _split_provider(self, provider: str) -> tuple[str, str]:
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name
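A minimal usage sketch of the adapter (the tenant id, canonical provider id, model name, and credentials below are placeholders):
from core.plugin.impl.model import PluginModelClient
from core.plugin.impl.model_runtime import PluginModelRuntime
from dify_graph.model_runtime.entities.model_entities import ModelType

runtime = PluginModelRuntime(tenant_id="tenant-id", user_id=None, client=PluginModelClient())
# The first lookup dispatches to the plugin daemon; the schema is then cached in Redis
# for PLUGIN_MODEL_SCHEMA_CACHE_TTL seconds and re-validated on every read.
schema = runtime.get_model_schema(
    provider="langgenius/openai/openai",
    model_type=ModelType.LLM,
    model="gpt-4o",
    credentials={"openai_api_key": "..."},
)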

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
"""Create a plugin runtime with its client dependency fully composed."""
from core.plugin.impl.model_runtime import PluginModelRuntime
return PluginModelRuntime(
tenant_id=tenant_id,
user_id=user_id,
client=PluginModelClient(),
)
def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
"""Create a tenant-bound model provider factory for service flows."""
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
return ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
"""Create a tenant-bound provider manager for service flows."""
from core.provider_manager import ProviderManager
return ProviderManager(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
"""Create a tenant-bound model manager for service flows."""
from core.model_manager import ModelManager
return ModelManager(
provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id),
)
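A sketch of the intended service-layer composition (tenant/user ids and model identifiers are placeholders; the ModelManager.for_tenant(...) calls elsewhere in this commit appear to be the classmethod counterpart of the same composition):
from core.plugin.impl.model_runtime_factory import create_plugin_model_manager
from dify_graph.model_runtime.entities.model_entities import ModelType

model_manager = create_plugin_model_manager(tenant_id="tenant-id", user_id="user-id")
model_instance = model_manager.get_model_instance(
    tenant_id="tenant-id",
    provider="langgenius/openai/openai",
    model_type=ModelType.LLM,
    model="gpt-4o",
)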

View File

@@ -1,50 +1,7 @@
from typing import Literal
from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from pydantic import BaseModel
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
class ChatModelMessage(BaseModel):
"""
Chat Message.
"""
text: str
role: PromptMessageRole
edition_type: Literal["basic", "jinja2"] | None = None
class CompletionModelPromptTemplate(BaseModel):
"""
Completion Model Prompt Template.
"""
text: str
edition_type: Literal["basic", "jinja2"] | None = None
class MemoryConfig(BaseModel):
"""
Memory Config.
"""
class RolePrefix(BaseModel):
"""
Role Prefix.
"""
user: str
assistant: str
class WindowConfig(BaseModel):
"""
Window Config.
"""
enabled: bool
size: int | None = None
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
__all__ = [
"ChatModelMessage",
"CompletionModelPromptTemplate",
"MemoryConfig",
]

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -53,15 +55,22 @@ from models.provider import (
from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService
if TYPE_CHECKING:
from dify_graph.model_runtime.runtime import ModelRuntime
class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
ProviderManager manages tenant-scoped model provider configuration.
The runtime adapter is injected by the composition layer so this class stays
focused on configuration assembly instead of constructing plugin runtimes.
"""
def __init__(self):
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@@ -127,7 +136,7 @@ class ProviderManager:
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()
# Get All preferred provider types of the workspace
@@ -321,7 +330,7 @@ class ProviderManager:
if not default_model:
return None
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
return DefaultModelEntity(

View File

@@ -106,9 +106,9 @@ class DataPostProcessor:
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

View File

@@ -328,7 +328,7 @@ class RetrievalService:
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
if dataset.is_multimodal:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model["reranking_provider_name"],

View File

@@ -303,7 +303,7 @@ class Vector:
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,

View File

@@ -72,7 +72,7 @@ class DatasetDocumentStore:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,

View File

@@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.rag.embedding.embedding_base import Embeddings
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -22,8 +22,20 @@ logger = logging.getLogger(__name__)
class CacheEmbedding(Embeddings):
def __init__(self, model_instance: ModelInstance, user: str | None = None):
self._model_instance = model_instance
self._user = user
self._model_instance = self._bind_model_instance(model_instance, user)
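# When a user is given, the instance is re-resolved through ModelManager.for_tenant
# so the acting end user is bound once at construction; the invoke_text_embedding /
# invoke_multimodal_embedding calls below no longer pass a per-call `user=` argument.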
@staticmethod
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
if user is None:
return model_instance
tenant_id = model_instance.provider_model_bundle.configuration.tenant_id
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
@@ -65,7 +77,7 @@ class CacheEmbedding(Embeddings):
batch_texts = embedding_queue_texts[i : i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
)
for vector in embedding_result.embeddings:
@@ -147,7 +159,6 @@ class CacheEmbedding(Embeddings):
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
@@ -202,7 +213,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
texts=[text], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
@@ -245,7 +256,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]

View File

@@ -12,7 +12,7 @@ from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -410,7 +410,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# If default prompt doesn't have {language} placeholder, use it as-is
pass
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)

View File

@@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner):
def __init__(self, rerank_model_instance: ModelInstance):
self.rerank_model_instance = rerank_model_instance
def _get_invoke_model_instance(self, user: str | None) -> ModelInstance:
if user is None:
return self.rerank_model_instance
tenant_id = self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
tenant_id=tenant_id,
provider=self.rerank_model_instance.provider,
model_type=ModelType.RERANK,
model=self.rerank_model_instance.model_name,
)
def run(
self,
query: str,
@@ -34,7 +46,9 @@ class RerankModelRunner(BaseRerankRunner):
:param user: unique user id if needed
:return:
"""
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
@@ -102,8 +116,8 @@ class RerankModelRunner(BaseRerankRunner):
docs.append(document.page_content)
unique_documents.append(document)
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
rerank_result = self._get_invoke_model_instance(user).invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
@@ -180,8 +194,8 @@ class RerankModelRunner(BaseRerankRunner):
"content": file_query,
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
rerank_result = self._get_invoke_model_instance(user).invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

View File

@@ -163,7 +163,7 @@ class WeightRerankRunner(BaseRerankRunner):
"""
query_vector_scores = []
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id,

View File

@@ -160,7 +160,7 @@ class DatasetRetrieval:
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
@@ -387,7 +387,7 @@ class DatasetRetrieval:
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
)
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)
# fetch prompt messages
prompt_messages, stop = self._get_prompt_template(
@@ -1430,7 +1430,6 @@ class DatasetRetrieval:
model_parameters=model_config.parameters,
stop=stop,
stream=True,
user=user_id,
),
)
@@ -1533,7 +1532,7 @@ class DatasetRetrieval:
return filters
def _fetch_model_config(
self, tenant_id: str, model: ModelConfig
self, tenant_id: str, model: ModelConfig, user_id: str | None = None
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
@@ -1543,7 +1542,7 @@ class DatasetRetrieval:
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)

View File

@@ -3,13 +3,14 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelType
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -150,19 +151,24 @@ class ReactMultiDatasetRouter:
:param stop: stop
:return:
"""
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=ModelType.LLM,
model=model_instance.model_name,
)
invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
)
# handle invoke result
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
return text, usage

View File

@@ -201,8 +201,10 @@ class HumanInputFormRepositoryImpl:
self,
*,
tenant_id: str,
app_id: str | None = None,
):
self._tenant_id = tenant_id
self._app_id = app_id
def _delivery_method_to_model(
self,
@@ -340,6 +342,9 @@ class HumanInputFormRepositoryImpl:
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
app_id = params.app_id or self._app_id
if not app_id:
raise ValueError("app_id is required to create a human input form")
with session_factory.create_session() as session, session.begin():
# Generate unique form ID
@@ -359,7 +364,7 @@ class HumanInputFormRepositoryImpl:
form_model = HumanInputForm(
id=form_id,
tenant_id=self._tenant_id,
app_id=params.app_id,
app_id=app_id,
workflow_run_id=params.workflow_execution_id,
form_kind=params.form_kind,
node_id=params.node_id,

View File

@@ -22,6 +22,9 @@ class ASRTool(BuiltinTool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
if not self.runtime:
raise ValueError("Runtime is required")
runtime = self.runtime
file = tool_parameters.get("audio_file")
if file.type != FileType.AUDIO: # type: ignore
yield self.create_text_message("not a valid audio file")
@@ -29,20 +32,19 @@ class ASRTool(BuiltinTool):
audio_binary = io.BytesIO(download(file)) # type: ignore
audio_binary.name = "temp.mp3"
provider, model = tool_parameters.get("model").split("#") # type: ignore
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
tenant_id=runtime.tenant_id,
provider=provider,
model_type=ModelType.SPEECH2TEXT,
model=model,
)
text = model_instance.invoke_speech2text(
file=audio_binary,
user=user_id,
)
text = model_instance.invoke_speech2text(file=audio_binary)
yield self.create_text_message(text)
def get_available_models(self) -> list[tuple[str, str]]:
if not self.runtime:
raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(
tenant_id=self.runtime.tenant_id, model_type="speech2text"

View File

@@ -20,13 +20,14 @@ class TTSTool(BuiltinTool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
provider, model = tool_parameters.get("model").split("#") # type: ignore
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
runtime = self.runtime
provider, model = tool_parameters.get("model").split("#") # type: ignore
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id or "",
tenant_id=runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
@@ -39,12 +40,7 @@ class TTSTool(BuiltinTool):
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
tenant_id=self.runtime.tenant_id,
voice=voice,
)
tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type]
buffer = io.BytesIO()
for chunk in tts:
buffer.write(chunk)

View File

@@ -65,7 +65,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,

View File

@@ -37,7 +37,7 @@ class ModelInvocationUtils:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -65,7 +65,7 @@ class ModelInvocationUtils:
"""
# get model instance
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
@@ -92,7 +92,7 @@ class ModelInvocationUtils:
"""
# get model manager
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
@@ -136,7 +136,6 @@ class ModelInvocationUtils:
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:

View File

@@ -9,8 +9,8 @@ from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm.model_access import build_dify_model_access, fetch_model_config
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
@@ -20,8 +20,17 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
DifyPreparedLLM,
DifyPromptMessageSerializer,
DifyRetrieverAttachmentLoader,
DifyToolFileManager,
DifyToolNodeRuntime,
build_dify_llm_file_saver,
)
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
from core.workflow.nodes.agent.plugin_strategy_adapter import (
PluginAgentStrategyPresentationProvider,
@@ -30,11 +39,9 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
from dify_graph.graph.graph import NodeFactory
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.base.node import Node
@@ -44,13 +51,10 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
from dify_graph.nodes.document_extractor import UnstructuredApiConfig
from dify_graph.nodes.http_request import build_http_request_config
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import TemplateRenderer
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
from dify_graph.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from dify_graph.template_rendering import CodeExecutorJinja2TemplateRenderer
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
@@ -264,11 +268,22 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(self._dify_context)
self._file_reference_factory = DifyFileReferenceFactory(self._dify_context)
self._prompt_message_serializer = DifyPromptMessageSerializer()
self._retriever_attachment_loader = DifyRetrieverAttachmentLoader(
file_reference_factory=self._file_reference_factory,
)
self._llm_file_saver = build_dify_llm_file_saver(
run_context=self._dify_context,
http_client=self._http_request_http_client,
)
self._human_input_runtime = DifyHumanInputNodeRuntime(self._dify_context)
self._tool_runtime = DifyToolNodeRuntime(self._dify_context)
self._http_request_file_manager = file_manager
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
@@ -284,7 +299,7 @@ class DifyNodeFactory(NodeFactory):
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context)
self._agent_strategy_resolver = PluginAgentStrategyResolver()
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
self._agent_runtime_support = AgentRuntimeSupport()
@@ -321,22 +336,32 @@ class DifyNodeFactory(NodeFactory):
"code_limits": self._code_limits,
},
BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: {
"template_renderer": self._template_renderer,
"jinja2_template_renderer": self._jinja2_template_renderer,
"max_output_length": self._template_transform_max_output_length,
},
BuiltinNodeTypes.HTTP_REQUEST: lambda: {
"http_request_config": self._http_request_config,
"http_client": self._http_request_http_client,
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
"tool_file_manager_factory": self._bound_tool_file_manager_factory,
"file_manager": self._http_request_file_manager,
"file_reference_factory": self._file_reference_factory,
},
BuiltinNodeTypes.HUMAN_INPUT: lambda: {
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
"form_repository": HumanInputFormRepositoryImpl(
tenant_id=self._dify_context.tenant_id,
app_id=self._dify_context.app_id,
),
"runtime": self._human_input_runtime,
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
include_prompt_message_serializer=True,
include_retriever_attachment_loader=True,
include_jinja2_template_renderer=True,
),
BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
@@ -345,15 +370,26 @@ class DifyNodeFactory(NodeFactory):
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
include_prompt_message_serializer=True,
include_retriever_attachment_loader=False,
include_jinja2_template_renderer=False,
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
include_prompt_message_serializer=True,
include_retriever_attachment_loader=False,
include_jinja2_template_renderer=False,
),
BuiltinNodeTypes.TOOL: lambda: {
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
"tool_file_manager_factory": self._bound_tool_file_manager_factory(),
"runtime": self._tool_runtime,
},
BuiltinNodeTypes.AGENT: lambda: {
"strategy_resolver": self._agent_strategy_resolver,
@@ -387,7 +423,12 @@ class DifyNodeFactory(NodeFactory):
*,
node_class: type[Node],
node_data: BaseNodeData,
wrap_model_instance: bool,
include_http_client: bool,
include_llm_file_saver: bool,
include_prompt_message_serializer: bool,
include_retriever_attachment_loader: bool,
include_jinja2_template_renderer: bool,
) -> dict[str, object]:
validated_node_data = cast(
LLMCompatibleNodeData,
@@ -397,49 +438,33 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,
"model_factory": self._llm_model_factory,
"model_instance": model_instance,
"model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance,
"memory": self._build_memory_for_llm_node(
node_data=validated_node_data,
model_instance=model_instance,
),
}
if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER:
node_init_kwargs["template_renderer"] = self._llm_template_renderer
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
if include_llm_file_saver:
node_init_kwargs["llm_file_saver"] = self._llm_file_saver
if include_prompt_message_serializer:
node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer
if include_retriever_attachment_loader:
node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader
if include_jinja2_template_renderer:
node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer
return node_init_kwargs
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
node_data_model = node_data.model
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
model_instance, _ = fetch_model_config(
node_data_model=node_data_model,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.stop = tuple(stop)
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance

View File

@@ -0,0 +1,532 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
from core.plugin.impl.plugin import PluginInstaller
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType
from core.tools.errors import ToolInvokeError
from core.tools.signature import sign_upload_file
from core.tools.tool_engine import ToolEngine
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.file import FileTransferMethod, FileType
from dify_graph.model_runtime.entities import LLMMode
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, apply_debug_email_recipient
from dify_graph.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
PromptMessageSerializerProtocol,
RetrieverAttachmentLoaderProtocol,
)
from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.runtime import (
HumanInputNodeRuntimeProtocol,
ToolNodeRuntimeProtocol,
)
from dify_graph.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError
from dify_graph.nodes.tool_runtime_entities import (
ToolRuntimeHandle,
ToolRuntimeMessage,
ToolRuntimeParameter,
)
from extensions.ext_database import db
from factories import file_factory
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
from dify_graph.file import File
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.tool.entities import ToolNodeData
def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext:
if isinstance(run_context, DifyRunContext):
return run_context
raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY)
if raw_ctx is None:
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
if isinstance(raw_ctx, DifyRunContext):
return raw_ctx
return DifyRunContext.model_validate(raw_ctx)
class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
def build_from_mapping(self, *, mapping: Mapping[str, Any]):
return file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self._run_context.tenant_id,
)
class DifyPreparedLLM(PreparedLLMProtocol):
"""Workflow-layer adapter that hides the full `ModelInstance` API from `dify_graph` nodes."""
def __init__(self, model_instance: ModelInstance) -> None:
self._model_instance = model_instance
@property
def provider(self) -> str:
return self._model_instance.provider
@property
def model_name(self) -> str:
return self._model_instance.model_name
@property
def parameters(self) -> Mapping[str, Any]:
return self._model_instance.parameters
@property
def stop(self) -> Sequence[str] | None:
return self._model_instance.stop
def get_model_schema(self) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema(
self._model_instance.model_name,
self._model_instance.credentials,
)
if model_schema is None:
raise ValueError(f"Model schema not found for {self._model_instance.model_name}")
return model_schema
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> LLMResult | Generator[LLMResultChunk, None, None]:
return self._model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=dict(model_parameters),
tools=list(tools or []),
stop=list(stop or []),
stream=stream,
)
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: bool,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
return invoke_llm_with_structured_output(
provider=self.provider,
model_schema=self.get_model_schema(),
model_instance=self._model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
stop=list(stop or []),
stream=stream,
)
def is_structured_output_parse_error(self, error: Exception) -> bool:
return isinstance(error, OutputParserError)
class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
def serialize(
self,
*,
model_mode: LLMMode,
prompt_messages: Sequence[PromptMessage],
) -> Any:
return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_mode,
prompt_messages=prompt_messages,
)
class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
"""Resolve retriever attachments through Dify persistence and return graph file references."""
def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None:
self._file_reference_factory = file_reference_factory
def load(self, *, segment_id: str) -> Sequence[File]:
with Session(db.engine, expire_on_commit=False) as session:
attachments_with_bindings = session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(SegmentAttachmentBinding.segment_id == segment_id)
).all()
return [
self._file_reference_factory.build_from_mapping(
mapping={
"id": upload_file.id,
"filename": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"type": FileType.IMAGE,
"transfer_method": FileTransferMethod.LOCAL_FILE,
"remote_url": upload_file.source_url,
"related_id": upload_file.id,
"upload_file_id": upload_file.id,
"size": upload_file.size,
"storage_key": upload_file.key,
"url": sign_upload_file(upload_file.id, upload_file.extension),
}
)
for _, upload_file in attachments_with_bindings
]
class DifyToolFileManager(ToolFileManagerProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
self._manager = ToolFileManager()
def create_file_by_raw(
self,
*,
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
filename: str | None = None,
) -> Any:
return self._manager.create_file_by_raw(
user_id=self._run_context.user_id,
tenant_id=self._run_context.tenant_id,
conversation_id=conversation_id,
file_binary=file_binary,
mimetype=mimetype,
filename=filename,
)
def get_file_generator_by_tool_file_id(self, tool_file_id: str):
return self._manager.get_file_generator_by_tool_file_id(tool_file_id)
@dataclass(frozen=True, slots=True)
class _WorkflowToolRuntimeSpec:
provider_type: CoreToolProviderType
provider_id: str
tool_name: str
tool_configurations: dict[str, Any]
credential_id: str | None = None
class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
self._file_reference_factory = DifyFileReferenceFactory(self._run_context)
@property
def file_reference_factory(self) -> FileReferenceFactoryProtocol:
return self._file_reference_factory
def build_file_reference(self, *, mapping: Mapping[str, Any]):
return self._file_reference_factory.build_from_mapping(mapping=mapping)
def get_runtime(
self,
*,
node_id: str,
node_data: ToolNodeData,
variable_pool,
) -> ToolRuntimeHandle:
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(
self._run_context.tenant_id,
self._run_context.app_id,
node_id,
self._build_tool_runtime_spec(node_data),
self._run_context.invoke_from,
variable_pool,
)
except ToolNodeError:
raise
except Exception as exc:
raise ToolRuntimeResolutionError(str(exc)) from exc
return ToolRuntimeHandle(raw=tool_runtime)
def get_runtime_parameters(
self,
*,
tool_runtime: ToolRuntimeHandle,
) -> Sequence[ToolRuntimeParameter]:
tool = self._tool_from_handle(tool_runtime)
return [
ToolRuntimeParameter(name=parameter.name, required=parameter.required)
for parameter in (tool.get_merged_runtime_parameters() or [])
]
def invoke(
self,
*,
tool_runtime: ToolRuntimeHandle,
tool_parameters: Mapping[str, Any],
workflow_call_depth: int,
conversation_id: str | None,
provider_name: str,
) -> Generator[ToolRuntimeMessage, None, None]:
tool = self._tool_from_handle(tool_runtime)
callback = DifyWorkflowCallbackHandler()
try:
messages = ToolEngine.generic_invoke(
tool=tool,
tool_parameters=dict(tool_parameters),
user_id=self._run_context.user_id,
workflow_tool_callback=callback,
workflow_call_depth=workflow_call_depth,
app_id=self._run_context.app_id,
conversation_id=conversation_id,
)
except Exception as exc:
raise self._map_invocation_exception(exc, provider_name=provider_name) from exc
transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self._run_context.user_id,
tenant_id=self._run_context.tenant_id,
conversation_id=None,
)
return self._adapt_messages(transformed_messages, provider_name=provider_name)
def get_usage(
self,
*,
tool_runtime: ToolRuntimeHandle,
) -> LLMUsage:
latest = getattr(tool_runtime.raw, "latest_usage", None)
if isinstance(latest, LLMUsage):
return latest
if isinstance(latest, dict):
return LLMUsage.model_validate(latest)
return LLMUsage.empty_usage()
def resolve_provider_icons(
self,
*,
provider_name: str,
default_icon: str | None = None,
) -> tuple[str | None, str | None]:
icon = default_icon
icon_dark = None
manager = PluginInstaller()
plugins = manager.list_plugins(self._run_context.tenant_id)
try:
current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name)
icon = current_plugin.declaration.icon
except StopIteration:
pass
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
self._run_context.user_id,
self._run_context.tenant_id,
)
if provider.name == provider_name
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
return icon, icon_dark
@staticmethod
def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool:
return cast("Tool", tool_runtime.raw)
@staticmethod
def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
return _WorkflowToolRuntimeSpec(
provider_type=CoreToolProviderType(node_data.provider_type.value),
provider_id=node_data.provider_id,
tool_name=node_data.tool_name,
tool_configurations=dict(node_data.tool_configurations),
credential_id=node_data.credential_id,
)
def _adapt_messages(
self,
messages: Generator[CoreToolInvokeMessage, None, None],
*,
provider_name: str,
) -> Generator[ToolRuntimeMessage, None, None]:
try:
for message in messages:
yield self._convert_message(message)
except Exception as exc:
raise self._map_invocation_exception(exc, provider_name=provider_name) from exc
def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage:
graph_message_type = ToolRuntimeMessage.MessageType(message.type.value)
graph_message = self._convert_message_payload(message.message)
graph_meta = message.meta.copy() if message.meta is not None else None
return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta)
def _convert_message_payload(
self,
message: CoreToolInvokeMessage.TextMessage
| CoreToolInvokeMessage.JsonMessage
| CoreToolInvokeMessage.BlobChunkMessage
| CoreToolInvokeMessage.BlobMessage
| CoreToolInvokeMessage.LogMessage
| CoreToolInvokeMessage.FileMessage
| CoreToolInvokeMessage.VariableMessage
| CoreToolInvokeMessage.RetrieverResourceMessage
| None,
) -> (
ToolRuntimeMessage.TextMessage
| ToolRuntimeMessage.JsonMessage
| ToolRuntimeMessage.BlobChunkMessage
| ToolRuntimeMessage.BlobMessage
| ToolRuntimeMessage.LogMessage
| ToolRuntimeMessage.FileMessage
| ToolRuntimeMessage.VariableMessage
| ToolRuntimeMessage.RetrieverResourceMessage
| None
):
if message is None:
return None
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
if isinstance(message, CoreToolInvokeMessage.TextMessage):
return ToolRuntimeMessage.TextMessage(text=message.text)
if isinstance(message, CoreToolInvokeMessage.JsonMessage):
return ToolRuntimeMessage.JsonMessage(
json_object=message.json_object,
suppress_output=message.suppress_output,
)
if isinstance(message, CoreToolInvokeMessage.BlobMessage):
return ToolRuntimeMessage.BlobMessage(blob=message.blob)
if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage):
return ToolRuntimeMessage.BlobChunkMessage(
id=message.id,
sequence=message.sequence,
total_length=message.total_length,
blob=message.blob,
end=message.end,
)
if isinstance(message, CoreToolInvokeMessage.FileMessage):
return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker)
if isinstance(message, CoreToolInvokeMessage.VariableMessage):
return ToolRuntimeMessage.VariableMessage(
variable_name=message.variable_name,
variable_value=message.variable_value,
stream=message.stream,
)
if isinstance(message, CoreToolInvokeMessage.LogMessage):
return ToolRuntimeMessage.LogMessage(
id=message.id,
label=message.label,
parent_id=message.parent_id,
error=message.error,
status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value),
data=dict(message.data),
metadata=dict(message.metadata),
)
if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage):
retriever_resources = [
resource.model_dump() if hasattr(resource, "model_dump") else dict(resource)
for resource in message.retriever_resources
]
return ToolRuntimeMessage.RetrieverResourceMessage(
retriever_resources=retriever_resources,
context=message.context,
)
raise TypeError(f"unsupported tool message payload: {type(message).__name__}")
@staticmethod
def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError:
if isinstance(exc, ToolNodeError):
return exc
if isinstance(exc, PluginInvokeError):
return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name))
if isinstance(exc, PluginDaemonClientSideError):
return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}")
if isinstance(exc, ToolInvokeError):
return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}")
return ToolRuntimeInvocationError(str(exc))
class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
def invoke_source(self) -> str:
invoke_from = self._run_context.invoke_from
if isinstance(invoke_from, str):
return invoke_from
return str(getattr(invoke_from, "value", invoke_from))
def apply_delivery_runtime(
self,
*,
methods: Sequence[DeliveryChannelConfig],
) -> Sequence[DeliveryChannelConfig]:
return [
apply_debug_email_recipient(
method,
enabled=self.invoke_source() == "debugger",
user_id=self._run_context.user_id,
)
for method in methods
]
def console_actor_id(self) -> str | None:
return self._run_context.user_id
def build_dify_llm_file_saver(
*,
run_context: Mapping[str, Any] | DifyRunContext,
http_client: HttpClientProtocol,
) -> LLMFileSaver:
from dify_graph.nodes.llm.file_saver import FileSaverImpl
return FileSaverImpl(
tool_file_manager=DifyToolFileManager(run_context),
file_reference_factory=DifyFileReferenceFactory(run_context),
http_client=http_client,
)
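A quick way to read the exception-normalization contract in _map_invocation_exception above: it is a staticmethod, so it can be exercised without a full run context. The sketch below is illustrative only and assumes DifyToolNodeRuntime, ToolRuntimeInvocationError, and ToolNodeError are importable from the module this hunk belongs to.
# Illustrative sketch, not part of the diff: generic exceptions are wrapped,
# while ToolNodeError subclasses pass through unchanged.
wrapped = DifyToolNodeRuntime._map_invocation_exception(ValueError("boom"), provider_name="example_provider")
assert isinstance(wrapped, ToolRuntimeInvocationError)
passthrough = DifyToolNodeRuntime._map_invocation_exception(
    ToolRuntimeInvocationError("already mapped"), provider_name="example_provider"
)
assert isinstance(passthrough, ToolRuntimeInvocationError)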

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
@@ -59,7 +60,7 @@ class AgentNode(Node[AgentNodeData]):
return "1"
def populate_start_event(self, event) -> None:
dify_ctx = self.require_dify_context()
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
event.extras["agent_strategy"] = {
"name": self.node_data.agent_strategy_name,
"icon": self._presentation_provider.get_icon(
@@ -71,7 +72,7 @@ class AgentNode(Node[AgentNodeData]):
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
try:
strategy = self._strategy_resolver.resolve(
@@ -97,6 +98,7 @@ class AgentNode(Node[AgentNodeData]):
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
)
@@ -106,6 +108,7 @@ class AgentNode(Node[AgentNodeData]):
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
for_log=True,
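The agent node above now rehydrates its Dify-specific context from the generic run-context mapping instead of a dedicated accessor. A minimal sketch of the expected shape, assuming DifyRunContext carries the same fields as the mapping-based context this commit removes from the base node (tenant_id, app_id, user_id, user_from, invoke_from); all values are placeholders.
# Illustrative only: the Dify fields live under DIFY_RUN_CONTEXT_KEY ("_dify") in the run context.
run_context = {
    DIFY_RUN_CONTEXT_KEY: {
        "tenant_id": "tenant-id",
        "app_id": "app-id",
        "user_id": "user-id",
        "user_from": "account",      # assumed field set, mirroring the protocol removed from node.py below
        "invoke_from": "debugger",
    }
}
# dify_ctx = DifyRunContext.model_validate(run_context[DIFY_RUN_CONTEXT_KEY])  # as in _run() above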

View File

@@ -14,7 +14,7 @@ from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from dify_graph.enums import SystemVariableKey
@@ -38,6 +38,7 @@ class AgentRuntimeSupport:
node_data: AgentNodeData,
strategy: ResolvedAgentStrategy,
tenant_id: str,
user_id: str,
app_id: str,
invoke_from: Any,
for_log: bool = False,
@@ -174,7 +175,11 @@ class AgentRuntimeSupport:
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
model_instance, model_schema = self.fetch_model(
tenant_id=tenant_id,
user_id=user_id,
value=value,
)
history_prompt_messages = []
if node_data.memory:
memory = self.fetch_memory(
@@ -232,8 +237,14 @@ class AgentRuntimeSupport:
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
def fetch_model(
self,
*,
tenant_id: str,
user_id: str,
value: dict[str, Any],
) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
@@ -246,7 +257,7 @@ class AgentRuntimeSupport:
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
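For reference, fetch_model now needs the acting user alongside the tenant, and it builds its provider manager through create_plugin_provider_manager. The sketch below shows the call shape only: support stands for an AgentRuntimeSupport instance, the "model" key is an assumption (only the "provider" and "model_type" lookups are visible in this hunk), and the identifiers are placeholders.
# Illustrative call shape, not part of the diff:
selector_value = {
    "provider": "langgenius/openai/openai",  # placeholder provider id
    "model_type": "llm",
    "model": "gpt-4o-mini",                  # assumed key, not visible in this hunk
}
# model_instance, model_schema = support.fetch_model(
#     tenant_id="tenant-id",
#     user_id="user-id",
#     value=selector_value,
# )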

View File

@@ -1,6 +1,7 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
@@ -50,8 +51,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
"""
Run the datasource node
"""
dify_ctx = self.require_dify_context()
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])

View File

@@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from dify_graph.entities import GraphInitParams
@@ -160,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[Source], LLMUsage]:
dify_ctx = self.require_dify_context()
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")

View File

@@ -9,9 +9,9 @@ from dify_graph.enums import NodeExecutionType
from dify_graph.file import FileTransferMethod
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.protocols import FileReferenceFactoryProtocol
from dify_graph.variables.types import SegmentType
from dify_graph.variables.variables import FileVariable
from factories import file_factory
from factories.variable_factory import build_segment_with_type
from .entities import ContentType, WebhookData
@@ -23,6 +23,13 @@ class TriggerWebhookNode(Node[WebhookData]):
node_type = TRIGGER_WEBHOOK_NODE_TYPE
execution_type = NodeExecutionType.ROOT
_file_reference_factory: FileReferenceFactoryProtocol
def post_init(self) -> None:
from core.workflow.node_runtime import DifyFileReferenceFactory
self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@@ -70,7 +77,6 @@ class TriggerWebhookNode(Node[WebhookData]):
)
def generate_file_var(self, param_name: str, file: dict):
dify_ctx = self.require_dify_context()
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
@@ -84,10 +90,7 @@ class TriggerWebhookNode(Node[WebhookData]):
file["datasource_file_id"] = related_id
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=dify_ctx.tenant_id,
)
file_obj = self._file_reference_factory.build_from_mapping(mapping=file)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
except ValueError:

View File

@@ -3,8 +3,6 @@ from typing import Any
from pydantic import BaseModel, Field
DIFY_RUN_CONTEXT_KEY = "_dify"
class GraphInitParams(BaseModel):
"""GraphInitParams encapsulates the configurations and contextual information

View File

@@ -14,7 +14,7 @@ from typing import Any
from pydantic import BaseModel, Field
from dify_graph.enums import WorkflowExecutionStatus, WorkflowType
from libs.datetime_utils import naive_utc_now
from dify_graph.utils.datetime_utils import naive_utc_now
class WorkflowExecution(BaseModel):

View File

@@ -10,7 +10,6 @@ from pydantic import TypeAdapter
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
from dify_graph.nodes.base.node import Node
from libs.typing import is_str
from .edge import Edge
from .validation import get_graph_validator
@@ -102,7 +101,7 @@ class Graph:
source = edge_config.get("source")
target = edge_config.get("target")
if not is_str(source) or not is_str(target):
if not isinstance(source, str) or not isinstance(target, str):
continue
# Create edge
@@ -110,7 +109,7 @@ class Graph:
edge_counter += 1
source_handle = edge_config.get("sourceHandle", "source")
if not is_str(source_handle):
if not isinstance(source_handle, str):
continue
edge = Edge(

View File

@@ -1,9 +1,9 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@@ -30,7 +30,7 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
context: str = Field(..., description="context")

View File

@@ -93,10 +93,14 @@ class ModelCredentialSchema(BaseModel):
class SimpleProviderEntity(BaseModel):
"""
Simple model class for provider.
Simplified provider schema exposed to callers.
`provider` is the canonical runtime identifier. `provider_name` is an optional
compatibility alias for short-name lookups and is empty when no alias exists.
"""
provider: str
provider_name: str = ""
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
@@ -115,10 +119,15 @@ class ProviderHelpEntity(BaseModel):
class ProviderEntity(BaseModel):
"""
Model class for provider.
Runtime-native provider schema.
`provider` is the canonical runtime identifier. `provider_name` is a
compatibility alias for callers that still resolve providers by short name and
is empty when no alias exists.
"""
provider: str
provider_name: str = ""
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
@@ -153,6 +162,7 @@ class ProviderEntity(BaseModel):
"""
return SimpleProviderEntity(
provider=self.provider,
provider_name=self.provider_name,
label=self.label,
icon_small=self.icon_small,
supported_model_types=self.supported_model_types,
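The provider/provider_name split documented above is what the factory changes later in this commit key on (see ModelProviderFactory._matches_provider). A short illustrative note on the intended lookup semantics; the concrete identifiers are placeholders.
# canonical runtime id:  ProviderEntity.provider       e.g. "langgenius/openai/openai"
# compatibility alias:   ProviderEntity.provider_name  e.g. "openai", or "" when no alias exists
# A lookup by either string resolves the same entity, so callers that still pass short names keep
# working while new code can rely on the canonical identifier alone.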

View File

@@ -1,6 +1,13 @@
from typing import TypedDict
from pydantic import BaseModel
class MultimodalRerankInput(TypedDict):
content: str
content_type: str
class RerankDocument(BaseModel):
"""
Model class for rerank document.

View File

@@ -1,10 +1,18 @@
from decimal import Decimal
from enum import StrEnum, auto
from pydantic import BaseModel
from dify_graph.model_runtime.entities.model_entities import ModelUsage
class EmbeddingInputType(StrEnum):
"""Embedding request input variants understood by the model runtime."""
DOCUMENT = auto()
QUERY = auto()
class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
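Because EmbeddingInputType is now a StrEnum built with auto(), each member's value is its lowercase name, so code that compares against the strings "document" and "query" keeps working. A small illustrative check:
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType
assert EmbeddingInputType.DOCUMENT == "document"  # auto() on StrEnum lowercases the member name
assert EmbeddingInputType.QUERY == "query"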

View File

@@ -1,12 +1,5 @@
import decimal
import hashlib
import logging
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from redis import RedisError
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from dify_graph.model_runtime.entities.model_entities import (
@@ -17,6 +10,7 @@ from dify_graph.model_runtime.entities.model_entities import (
PriceInfo,
PriceType,
)
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@@ -25,45 +19,61 @@ from dify_graph.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
from dify_graph.model_runtime.runtime import ModelRuntime
class AIModel(BaseModel):
class AIModel:
"""
Base class for all models.
Runtime-facing base class for all models.
This stays a regular Python class: instances hold live collaborators (the
provider schema and the runtime adapter) rather than user input, so Pydantic
validation adds nothing here. Subclasses must pin ``model_type`` via a
class attribute; the base class is not meant to be instantiated directly.
"""
tenant_id: str = Field(description="Tenant ID")
model_type: ModelType = Field(description="Model type")
plugin_id: str = Field(description="Plugin ID")
provider_name: str = Field(description="Provider")
plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider")
started_at: float = Field(description="Invoke start time", default=0)
model_type: ModelType
provider_schema: ProviderEntity
model_runtime: ModelRuntime
started_at: float
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(
self,
provider_schema: ProviderEntity,
model_runtime: ModelRuntime,
*,
started_at: float = 0,
) -> None:
if getattr(type(self), "model_type", None) is None:
raise TypeError("AIModel subclasses must define model_type as a class attribute")
self.model_type = type(self).model_type
self.provider_schema = provider_schema
self.model_runtime = model_runtime
self.started_at = started_at
@property
def provider(self) -> str:
return self.provider_schema.provider
@property
def provider_display_name(self) -> str:
return self.provider_schema.label.en_US
@property
def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
Map model invoke error to unified error.
:return: Invoke error mapping
The key is the error type thrown to the caller, and the value contains
runtime-facing exception types that should be normalized to it.
"""
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError],
PluginDaemonInnerError: [PluginDaemonInnerError],
ValueError: [ValueError],
}
@@ -79,15 +89,18 @@ class AIModel(BaseModel):
if invoke_error == InvokeAuthorizationError:
return InvokeAuthorizationError(
description=(
f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
f"[{self.provider_display_name}] Incorrect model credentials provided, "
"please check and try again."
)
)
elif isinstance(invoke_error, InvokeError):
return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
return InvokeError(
description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}"
)
else:
return error
return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}")
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
"""
@@ -144,65 +157,13 @@ class AIModel(BaseModel):
:param credentials: model credentials
:return: model schema
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
return self.model_runtime.get_model_schema(
provider=self.provider,
model_type=self.model_type,
model=model,
credentials=credentials or {},
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema from credentials
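To make the new construction contract concrete: subclasses pin model_type as a class attribute and instances receive a provider schema plus a runtime adapter. The sketch below is illustrative; ExampleRerankModel is a hypothetical subclass, and the instantiation is left commented out because it needs a real ProviderEntity and ModelRuntime.
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
class ExampleRerankModel(AIModel):  # hypothetical subclass, for illustration only
    model_type: ModelType = ModelType.RERANK
# model = ExampleRerankModel(provider_schema=provider_entity, model_runtime=runtime)
# model.provider               -> provider_entity.provider
# model.provider_display_name  -> provider_entity.label.en_US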

View File

@@ -4,9 +4,6 @@ import uuid
from collections.abc import Callable, Generator, Iterator, Sequence
from typing import Union
from pydantic import ConfigDict
from configs import dify_config
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
@@ -140,11 +137,9 @@ def _build_llm_result_from_chunks(
)
def _invoke_llm_via_plugin(
def _invoke_llm_via_runtime(
*,
tenant_id: str,
user_id: str,
plugin_id: str,
llm_model: "LargeLanguageModel",
provider: str,
model: str,
credentials: dict,
@@ -154,25 +149,19 @@ def _invoke_llm_via_plugin(
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_llm(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
return llm_model.model_runtime.invoke_llm(
provider=provider,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stop=stop,
stream=stream,
)
def _normalize_non_stream_plugin_result(
def _normalize_non_stream_runtime_result(
model: str,
prompt_messages: Sequence[PromptMessage],
result: Union[LLMResult, Iterator[LLMResultChunk]],
@@ -208,9 +197,6 @@ class LargeLanguageModel(AIModel):
model_type: ModelType = ModelType.LLM
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
@@ -220,7 +206,6 @@ class LargeLanguageModel(AIModel):
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
"""
@@ -233,7 +218,6 @@ class LargeLanguageModel(AIModel):
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -245,7 +229,7 @@ class LargeLanguageModel(AIModel):
callbacks = callbacks or []
if dify_config.DEBUG:
if logger.isEnabledFor(logging.DEBUG):
callbacks.append(LoggingCallback())
# trigger before invoke callbacks
@@ -257,18 +241,15 @@ class LargeLanguageModel(AIModel):
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
result = _invoke_llm_via_plugin(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
result = _invoke_llm_via_runtime(
llm_model=self,
provider=self.provider,
model=model,
credentials=credentials,
model_parameters=model_parameters,
@@ -279,7 +260,7 @@ class LargeLanguageModel(AIModel):
)
if not stream:
result = _normalize_non_stream_plugin_result(
result = _normalize_non_stream_runtime_result(
model=model, prompt_messages=prompt_messages, result=result
)
except Exception as e:
@@ -292,7 +273,6 @@ class LargeLanguageModel(AIModel):
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
@@ -309,7 +289,6 @@ class LargeLanguageModel(AIModel):
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
elif isinstance(result, LLMResult):
@@ -322,7 +301,6 @@ class LargeLanguageModel(AIModel):
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
# Following https://github.com/langgenius/dify/issues/17799,
@@ -435,22 +413,14 @@ class LargeLanguageModel(AIModel):
:param tools: tools for tool calling
:return:
"""
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
tools=tools,
)
return 0
return self.model_runtime.get_llm_num_tokens(
provider=self.provider,
model_type=self.model_type,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
tools=tools,
)
def calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
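One behavioural consequence of the change above: verbose prompt/response logging is now gated by the standard logging configuration instead of dify_config.DEBUG. A minimal sketch; the module path is assumed from the file layout in this diff.
import logging
# Raising the module logger to DEBUG attaches LoggingCallback on invoke (module path assumed).
logging.getLogger("dify_graph.model_runtime.model_providers.__base.large_language_model").setLevel(logging.DEBUG)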

View File

@@ -1,7 +1,5 @@
import time
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -13,30 +11,20 @@ class ModerationModel(AIModel):
model_type: ModelType = ModelType.MODERATION
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
def invoke(self, model: str, credentials: dict, text: str) -> bool:
"""
Invoke moderation model
:param model: model name
:param credentials: model credentials
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
self.started_at = time.perf_counter()
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_moderation(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_moderation(
provider=self.provider,
model=model,
credentials=credentials,
text=text,

View File

@@ -1,5 +1,5 @@
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -18,7 +18,6 @@ class RerankModel(AIModel):
docs: list[str],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -29,18 +28,11 @@ class RerankModel(AIModel):
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_rerank(
provider=self.provider,
model=model,
credentials=credentials,
query=query,
@@ -55,11 +47,10 @@ class RerankModel(AIModel):
self,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
@@ -69,18 +60,11 @@ class RerankModel(AIModel):
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_multimodal_rerank(
provider=self.provider,
model=model,
credentials=credentials,
query=query,
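The new MultimodalRerankInput TypedDict tightens the multimodal rerank signature above. The sketch below shows the call shape against the ModelRuntime port introduced later in this commit; the content_type values, provider id, and model name are placeholders or assumptions.
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput
query: MultimodalRerankInput = {"content": "a red bicycle", "content_type": "text"}
docs: list[MultimodalRerankInput] = [
    {"content": "https://example.com/bike.png", "content_type": "image"},  # content_type values assumed
    {"content": "a blue car", "content_type": "text"},
]
# result = model_runtime.invoke_multimodal_rerank(
#     provider="langgenius/example/example", model="example-rerank", credentials={},
#     query=query, docs=docs, score_threshold=None, top_n=3,
# )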

View File

@@ -1,7 +1,5 @@
from typing import IO
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -13,28 +11,18 @@ class Speech2TextModel(AIModel):
model_type: ModelType = ModelType.SPEECH2TEXT
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
"""
Invoke speech to text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_speech_to_text(
provider=self.provider,
model=model,
credentials=credentials,
file=file,

View File

@@ -1,8 +1,5 @@
from pydantic import ConfigDict
from core.entities.embedding_type import EmbeddingInputType
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -13,16 +10,12 @@ class TextEmbeddingModel(AIModel):
model_type: ModelType = ModelType.TEXT_EMBEDDING
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
credentials: dict,
texts: list[str] | None = None,
multimodel_documents: list[dict] | None = None,
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
@@ -32,31 +25,21 @@ class TextEmbeddingModel(AIModel):
:param credentials: model credentials
:param texts: texts to embed
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
from core.plugin.impl.model import PluginModelClient
try:
plugin_model_manager = PluginModelClient()
if texts:
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_text_embedding(
provider=self.provider,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if multimodel_documents:
return plugin_model_manager.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_multimodal_embedding(
provider=self.provider,
model=model,
credentials=credentials,
documents=multimodel_documents,
@@ -75,14 +58,8 @@ class TextEmbeddingModel(AIModel):
:param texts: texts to embed
:return:
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.get_text_embedding_num_tokens(
provider=self.provider,
model=model,
credentials=credentials,
texts=texts,

View File

@@ -1,8 +1,6 @@
import logging
from collections.abc import Iterable
from pydantic import ConfigDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -16,38 +14,25 @@ class TTSModel(AIModel):
model_type: ModelType = ModelType.TTS
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: str | None = None,
) -> Iterable[bytes]:
"""
Invoke text-to-speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: translated audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_tts(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.invoke_tts(
provider=self.provider,
model=model,
credentials=credentials,
content_text=content_text,
@@ -65,14 +50,8 @@ class TTSModel(AIModel):
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
return self.model_runtime.get_tts_model_voices(
provider=self.provider,
model=model,
credentials=credentials,
language=language,

View File

@@ -1,16 +1,7 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Sequence
from threading import Lock
from pydantic import ValidationError
from redis import RedisError
import contexts
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -20,120 +11,64 @@ from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankM
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
from dify_graph.model_runtime.runtime import ModelRuntime
from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import (
ProviderCredentialSchemaValidator,
)
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class ModelProviderFactory:
def __init__(self, tenant_id: str):
from core.plugin.impl.model import PluginModelClient
"""Factory for provider schemas and model-type instances backed by a runtime adapter."""
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelClient()
def __init__(self, model_runtime: ModelRuntime):
if model_runtime is None:
raise ValueError("model_runtime is required.")
self.model_runtime = model_runtime
def get_providers(self) -> Sequence[ProviderEntity]:
"""
Get all providers
:return: list of providers
Get all providers.
"""
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
# The plugin server should return providers in the desired order
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
return list(self.get_model_providers())
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
def get_model_providers(self) -> Sequence[ProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
Get all model providers exposed by the runtime adapter.
"""
# check if context is set
try:
contexts.plugin_model_providers.get()
except LookupError:
contexts.plugin_model_providers.set(None)
contexts.plugin_model_providers_lock.set(Lock())
with contexts.plugin_model_providers_lock.get():
plugin_model_providers = contexts.plugin_model_providers.get()
if plugin_model_providers is not None:
return plugin_model_providers
plugin_model_providers = []
contexts.plugin_model_providers.set(plugin_model_providers)
# Fetch plugin model providers
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
for provider in plugin_providers:
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
plugin_model_providers.append(provider)
return plugin_model_providers
return self.model_runtime.fetch_model_providers()
def get_provider_schema(self, provider: str) -> ProviderEntity:
"""
Get provider schema
:param provider: provider name
:return: provider schema
Get provider schema.
"""
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
return self.get_model_provider(provider=provider)
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
def get_model_provider(self, provider: str) -> ProviderEntity:
"""
Get plugin model provider
:param provider: provider name
:return: provider schema
Get provider schema.
"""
if "/" not in provider:
provider = str(ModelProviderID(provider))
# fetch plugin model providers
plugin_model_provider_entities = self.get_plugin_model_providers()
# get the provider
plugin_model_provider_entity = next(
(p for p in plugin_model_provider_entities if p.declaration.provider == provider),
None,
)
if not plugin_model_provider_entity:
provider_entity = self._resolve_provider(provider)
if provider_entity is None:
raise ValueError(f"Invalid provider: {provider}")
return plugin_model_provider_entity
return provider_entity
def provider_credentials_validate(self, *, provider: str, credentials: dict):
"""
Validate provider credentials
:param provider: provider name
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
:return:
Validate provider credentials.
"""
# fetch plugin model provider
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
provider_entity = self.get_model_provider(provider=provider)
# get provider_credential_schema and validate credentials according to the rules
provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema
provider_credential_schema = provider_entity.provider_credential_schema
if not provider_credential_schema:
raise ValueError(f"Provider {provider} does not have provider_credential_schema")
# validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
# validate the credentials, raise exception if validation failed
self.plugin_model_manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=plugin_model_provider_entity.provider,
self.model_runtime.validate_provider_credentials(
provider=provider_entity.provider,
credentials=filtered_credentials,
)
@@ -141,33 +76,20 @@ class ModelProviderFactory:
def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict):
"""
Validate model credentials
:param provider: provider name
:param model_type: model type
:param model: model name
:param credentials: model credentials, credentials form defined in `model_credential_schema`.
:return:
Validate model credentials.
"""
# fetch plugin model provider
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
provider_entity = self.get_model_provider(provider=provider)
# get model_credential_schema and validate credentials according to the rules
model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema
model_credential_schema = provider_entity.model_credential_schema
if not model_credential_schema:
raise ValueError(f"Provider {provider} does not have model_credential_schema")
# validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
# call validate_credentials method of model type to validate credentials, raise exception if validation failed
self.plugin_model_manager.validate_model_credentials(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=plugin_model_provider_entity.provider,
model_type=model_type.value,
self.model_runtime.validate_model_credentials(
provider=provider_entity.provider,
model_type=model_type,
model=model,
credentials=filtered_credentials,
)
@@ -178,65 +100,16 @@ class ModelProviderFactory:
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
) -> AIModelEntity | None:
"""
Get model schema
Get model schema.
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
provider_entity = self.get_model_provider(provider)
return self.model_runtime.get_model_schema(
provider=provider_entity.provider,
model_type=model_type,
model=model,
credentials=credentials or {},
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_models(
self,
*,
@@ -245,143 +118,56 @@ class ModelProviderFactory:
provider_configs: list[ProviderConfig] | None = None,
) -> list[SimpleProviderEntity]:
"""
Get all models for given model type
:param provider: provider name
:param model_type: model type
:param provider_configs: list of provider configs
:return: list of models
Get all models for given model type.
"""
provider_configs = provider_configs or []
# scan all providers
plugin_model_provider_entities = self.get_plugin_model_providers()
# traverse all model_provider_extensions
providers = []
for plugin_model_provider_entity in plugin_model_provider_entities:
# filter by provider if provider is present
if provider and plugin_model_provider_entity.declaration.provider != provider:
for provider_entity in self.get_model_providers():
if provider and not self._matches_provider(provider_entity, provider):
continue
# get provider schema
provider_schema = plugin_model_provider_entity.declaration
model_types = provider_schema.supported_model_types
if model_type:
if model_type not in model_types:
continue
model_types = [model_type]
all_model_type_models = []
for model_schema in provider_schema.models:
if model_schema.model_type != model_type:
continue
all_model_type_models.append(model_schema)
simple_provider_schema = provider_schema.to_simple_provider()
if model_type:
simple_provider_schema.models = all_model_type_models
if model_type and model_type not in provider_entity.supported_model_types:
continue
simple_provider_schema = provider_entity.to_simple_provider()
if model_type is not None:
simple_provider_schema.models = [
model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type
]
providers.append(simple_provider_schema)
return providers
def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel:
"""
Get model type instance by provider name and model type
:param provider: provider name
:param model_type: model type
:return: model type instance
Get model type instance by provider name and model type.
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
init_params = {
"tenant_id": self.tenant_id,
"plugin_id": plugin_id,
"provider_name": provider_name,
"plugin_model_provider": self.get_plugin_model_provider(provider),
}
provider_schema = self.get_model_provider(provider)
if model_type == ModelType.LLM:
return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK:
return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION:
return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS:
return TTSModel.model_validate(init_params)
return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.RERANK:
return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.MODERATION:
return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
if model_type == ModelType.TTS:
return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
raise ValueError(f"Unsupported model type: {model_type}")
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""
Get provider icon
:param provider: provider name
:param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return: provider icon
Get provider icon.
"""
# get the provider schema
provider_schema = self.get_provider_schema(provider)
provider_entity = self.get_model_provider(provider)
return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang)
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
def _resolve_provider(self, provider: str) -> ProviderEntity | None:
return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None)
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_small.zh_Hans
else:
file_name = provider_schema.icon_small.en_US
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_small_dark.zh_Hans
else:
file_name = provider_schema.icon_small_dark.en_US
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}
extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
# get icon bytes from plugin asset manager
from core.plugin.impl.asset import PluginAssetManager
plugin_asset_manager = PluginAssetManager()
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]:
"""
Get plugin id and provider name from provider name
:param provider: provider name
:return: plugin id and provider name
"""
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name
@staticmethod
def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool:
return provider in (provider_entity.provider, provider_entity.provider_name)
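Putting the factory changes together: the factory is now constructed from a ModelRuntime adapter and resolves providers by canonical id or short-name alias. An illustrative wiring sketch; runtime stands for any object satisfying the ModelRuntime protocol, and the provider ids are placeholders.
from dify_graph.model_runtime.entities.model_entities import ModelType
# factory   = ModelProviderFactory(model_runtime=runtime)
# providers = factory.get_model_providers()                              # Sequence[ProviderEntity]
# schema    = factory.get_provider_schema("langgenius/openai/openai")    # canonical id
# llm       = factory.get_model_type_instance("openai", ModelType.LLM)   # short-name alias also resolves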

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
from collections.abc import Generator, Iterable, Sequence
from typing import IO, Any, Protocol, Union, runtime_checkable
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
@runtime_checkable
class ModelRuntime(Protocol):
"""Port for provider discovery, schema lookup, and model execution.
`provider` is the model runtime's canonical provider identifier. Adapters may
derive transport-specific details from it, but those details stay outside
this boundary.
"""
def fetch_model_providers(self) -> Sequence[ProviderEntity]: ...
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ...
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ...
def validate_model_credentials(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> None: ...
def get_model_schema(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> AIModelEntity | None: ...
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ...
def get_llm_num_tokens(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
) -> int: ...
def invoke_text_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
input_type: EmbeddingInputType,
) -> EmbeddingResult: ...
def invoke_multimodal_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
documents: list[dict[str, Any]],
input_type: EmbeddingInputType,
) -> EmbeddingResult: ...
def get_text_embedding_num_tokens(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]: ...
def invoke_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult: ...
def invoke_multimodal_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult: ...
def invoke_tts(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Iterable[bytes]: ...
def get_tts_model_voices(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
language: str | None,
) -> Any: ...
def invoke_speech_to_text(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
file: IO[bytes],
) -> str: ...
def invoke_moderation(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
text: str,
) -> bool: ...
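Because ModelRuntime is a structural Protocol, adapters and test doubles do not have to inherit from it; they only need to provide the members the code under test calls. A minimal stub sketch, illustrative and not part of the diff:
from collections.abc import Sequence
from typing import Any
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
class StubModelRuntime:
    """Backs provider discovery and schema lookup in tests; other members are left unimplemented."""
    def __init__(self, providers: Sequence[ProviderEntity]) -> None:
        self._providers = list(providers)
    def fetch_model_providers(self) -> Sequence[ProviderEntity]:
        return list(self._providers)
    def get_model_schema(
        self, *, provider: str, model_type: ModelType, model: str, credentials: dict[str, Any]
    ) -> AIModelEntity | None:
        return None  # pretend the provider publishes no predefined schema for this model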

View File

@@ -6,13 +6,12 @@ from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from dify_graph.entities import GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
ErrorStrategy,
NodeExecutionType,
@@ -60,7 +59,7 @@ from dify_graph.node_events import (
StreamCompletedEvent,
)
from dify_graph.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from dify_graph.utils.datetime_utils import naive_utc_now
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()
@@ -68,23 +67,6 @@ _MISSING_RUN_CONTEXT_VALUE = object()
logger = logging.getLogger(__name__)
class DifyRunContextProtocol(Protocol):
tenant_id: str
app_id: str
user_id: str
user_from: Any
invoke_from: Any
class _MappingDifyRunContext:
def __init__(self, mapping: Mapping[str, Any]) -> None:
self.tenant_id = str(mapping["tenant_id"])
self.app_id = str(mapping["app_id"])
self.user_id = str(mapping["user_id"])
self.user_from = mapping["user_from"]
self.invoke_from = mapping["invoke_from"]
class Node(Generic[NodeDataT]):
"""BaseNode serves as the foundational class for all node implementations.
@@ -177,8 +159,9 @@ class Node(Generic[NodeDataT]):
# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under the
# canonical workflow namespaces.
# Only treat nodes from the base dify_graph package as production
# registrations. Higher-layer packages may still register subclasses,
# but dify_graph itself should not know their module identities.
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")
@@ -186,7 +169,7 @@ class Node(Generic[NodeDataT]):
node_type = cls.node_type
version = cls.version()
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")):
if module_name.startswith("dify_graph.nodes."):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:
@@ -299,25 +282,6 @@ class Node(Generic[NodeDataT]):
raise ValueError(f"run_context missing required key: {key}")
return value
def require_dify_context(self) -> DifyRunContextProtocol:
raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
if raw_ctx is None:
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
if isinstance(raw_ctx, Mapping):
missing_keys = [
key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
]
if missing_keys:
raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
return _MappingDifyRunContext(raw_ctx)
for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
if not hasattr(raw_ctx, attr):
raise TypeError(f"invalid dify context object, missing attribute: {attr}")
return cast(DifyRunContextProtocol, raw_ctx)
@property
def execution_id(self) -> str:
return self._node_execution_id
@@ -793,16 +757,11 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
retriever_resources = [
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
]
return NodeRunRetrieverResourceEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=retriever_resources,
retriever_resources=event.retriever_resources,
context=event.context,
node_version=self.version(),
)
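Note on the narrowed registration rule a few hunks above: only subclasses whose module lives under dify_graph.nodes. may overwrite an existing (node_type, version) slot. A minimal sketch of that precedence logic follows; the helper name is illustrative, and the non-production branch is assumed to be a setdefault-style registration since the else body is not shown in the hunk.

def _register_node_class(registry: dict, node_type, version: str, cls, module_name: str) -> None:
    # Mirrors the check in the subclass hook: production modules win,
    # test helpers and higher-layer subclasses never shadow an existing slot.
    bucket = registry.setdefault(node_type, {})
    if module_name.startswith("dify_graph.nodes."):
        bucket[version] = cls  # production definition takes precedence
    else:
        bucket.setdefault(version, cls)  # assumption: never overrides a production node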

View File

@@ -11,9 +11,13 @@ from dify_graph.nodes.base import variable_template_parser
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.http_request.executor import Executor
from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.protocols import (
FileManagerProtocol,
FileReferenceFactoryProtocol,
HttpClientProtocol,
ToolFileManagerProtocol,
)
from dify_graph.variables.segments import ArrayFileSegment
from factories import file_factory
from .config import build_http_request_config, resolve_http_request_config
from .entities import (
@@ -46,6 +50,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
http_client: HttpClientProtocol,
tool_file_manager_factory: Callable[[], ToolFileManagerProtocol],
file_manager: FileManagerProtocol,
file_reference_factory: FileReferenceFactoryProtocol,
) -> None:
super().__init__(
id=id,
@@ -58,6 +63,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
self._http_client = http_client
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager
self._file_reference_factory = file_reference_factory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -212,7 +218,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
"""
Extract files from response by checking both Content-Type header and URL
"""
dify_ctx = self.require_dify_context()
files: list[File] = []
is_file = response.is_file
content_type = response.content_type
@@ -237,20 +242,16 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=mime_type,
)
mapping = {
"tool_file_id": tool_file.id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=dify_ctx.tenant_id,
file = self._file_reference_factory.build_from_mapping(
mapping={
"tool_file_id": tool_file.id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
)
files.append(file)
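For reference, the injected FileReferenceFactoryProtocol can be satisfied in the application layer by a thin wrapper over the existing factories.file_factory call that this node used to make directly. The class name and its tenant binding below are illustrative, not part of this commit.

from collections.abc import Mapping
from typing import Any

from dify_graph.file.models import File
from factories import file_factory


class TenantFileReferenceFactory:
    """Hypothetical adapter: re-attaches the tenant scope the node no longer knows about."""

    def __init__(self, tenant_id: str) -> None:
        self._tenant_id = tenant_id

    def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File:
        return file_factory.build_from_mapping(mapping=dict(mapping), tenant_id=self._tenant_id)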

View File

@@ -15,15 +15,16 @@ from dify_graph.node_events import (
from dify_graph.node_events.base import NodeEventBase
from dify_graph.node_events.node import StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.runtime import HumanInputNodeRuntimeProtocol
from dify_graph.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from dify_graph.utils.datetime_utils import naive_utc_now
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
from .entities import DeliveryChannelConfig, HumanInputNodeData
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
if TYPE_CHECKING:
@@ -68,6 +69,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository,
runtime: HumanInputNodeRuntimeProtocol | None = None,
) -> None:
super().__init__(
id=id,
@@ -76,6 +78,9 @@ class HumanInputNode(Node[HumanInputNodeData]):
graph_runtime_state=graph_runtime_state,
)
self._form_repository = form_repository
if runtime is None:
raise ValueError("runtime is required")
self._runtime = runtime
@classmethod
def version(cls) -> str:
@@ -171,25 +176,14 @@ class HumanInputNode(Node[HumanInputNodeData]):
return self._node_data.is_webapp_enabled()
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
dify_ctx = self.require_dify_context()
invoke_from = self._invoke_from_value()
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}:
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
return [
apply_debug_email_recipient(
method,
enabled=invoke_from == _INVOKE_FROM_DEBUGGER,
user_id=dify_ctx.user_id,
)
for method in enabled_methods
]
return self._runtime.apply_delivery_runtime(methods=enabled_methods)
def _invoke_from_value(self) -> str:
invoke_from = self.require_dify_context().invoke_from
if isinstance(invoke_from, str):
return invoke_from
return str(getattr(invoke_from, "value", invoke_from))
return self._runtime.invoke_source()
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
node_data = self._node_data
@@ -224,11 +218,9 @@ class HumanInputNode(Node[HumanInputNodeData]):
"""
repo = self._form_repository
form = repo.get_form(self._workflow_execution_id, self.id)
dify_ctx = self.require_dify_context()
if form is None:
display_in_ui = self._display_in_ui()
params = FormCreateParams(
app_id=dify_ctx.app_id,
workflow_execution_id=self._workflow_execution_id,
node_id=self.id,
form_config=self._node_data,
@@ -238,7 +230,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
resolved_default_values=self.resolve_default_values(),
console_recipient_required=self._should_require_console_recipient(),
console_creator_account_id=(
dify_ctx.user_id
self._runtime.console_actor_id()
if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
else None
),
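The HumanInputNodeRuntimeProtocol definition itself is not part of this hunk, so the adapter below is only an inference from the call sites above (invoke_source, console_actor_id, apply_delivery_runtime); the class name, constructor arguments, and passthrough behaviour are hypothetical.

from collections.abc import Sequence


class ConsoleHumanInputRuntime:
    """Sketch of a debugger-aware runtime an application layer might inject."""

    def __init__(self, *, user_id: str, invoke_from: str) -> None:
        self._user_id = user_id
        self._invoke_from = invoke_from

    def invoke_source(self) -> str:
        return self._invoke_from

    def console_actor_id(self) -> str:
        return self._user_id

    def apply_delivery_runtime(self, *, methods: Sequence) -> Sequence:
        # Would re-apply the debug-email redirection the node previously did inline
        # (apply_debug_email_recipient); shown as a passthrough here for brevity.
        return list(methods)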

View File

@@ -34,10 +34,10 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from dify_graph.runtime import VariablePool
from dify_graph.utils.datetime_utils import naive_utc_now
from dify_graph.variables import IntegerVariable, NoneSegment
from dify_graph.variables.segments import ArrayAnySegment, ArraySegment
from dify_graph.variables.variables import Variable
from libs.datetime_utils import naive_utc_now
from .exc import (
InvalidIteratorValueError,

View File

@@ -3,11 +3,11 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
class ModelConfig(BaseModel):

View File

@@ -2,10 +2,8 @@ import mimetypes
import typing as tp
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol
class LLMFileSaver(tp.Protocol):
@@ -57,17 +55,20 @@ class LLMFileSaver(tp.Protocol):
class FileSaverImpl(LLMFileSaver):
_tenant_id: str
_user_id: str
_tool_file_manager: ToolFileManagerProtocol
_file_reference_factory: FileReferenceFactoryProtocol
def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
self._user_id = user_id
self._tenant_id = tenant_id
def __init__(
self,
*,
tool_file_manager: ToolFileManagerProtocol,
file_reference_factory: FileReferenceFactoryProtocol,
http_client: HttpClientProtocol,
):
self._tool_file_manager = tool_file_manager
self._file_reference_factory = file_reference_factory
self._http_client = http_client
def _get_tool_file_manager(self):
return ToolFileManager()
def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = self._http_client.get(url)
http_response.raise_for_status()
@@ -83,10 +84,7 @@ class FileSaverImpl(LLMFileSaver):
file_type: FileType,
extension_override: str | None = None,
) -> File:
tool_file_manager = self._get_tool_file_manager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
tool_file = self._tool_file_manager.create_file_by_raw(
# TODO(QuantumGhost): what is conversation id?
conversation_id=None,
file_binary=data,
@@ -94,19 +92,18 @@ class FileSaverImpl(LLMFileSaver):
)
extension_override = _validate_extension_override(extension_override)
extension = _get_extension(mime_type, extension_override)
url = sign_tool_file(tool_file.id, extension)
return File(
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=tool_file.name,
extension=extension,
mime_type=mime_type,
size=len(data),
related_id=tool_file.id,
url=url,
storage_key=tool_file.file_key,
return self._file_reference_factory.build_from_mapping(
mapping={
"type": file_type,
"transfer_method": FileTransferMethod.TOOL_FILE,
"filename": tool_file.name,
"extension": extension,
"mime_type": mime_type,
"size": len(data),
"tool_file_id": tool_file.id,
"related_id": tool_file.id,
"storage_key": tool_file.file_key,
}
)
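Because ToolFileManagerProtocol no longer accepts user_id/tenant_id (see the protocols change at the end of this commit), the host layer presumably pre-binds them before wiring FileSaverImpl. A hedged sketch of such a wrapper; the class name is illustrative and nothing here is part of this commit.

from core.tools.tool_file_manager import ToolFileManager


class TenantScopedToolFileManager:
    """Hypothetical wrapper that pre-binds the identity dropped from the protocol."""

    def __init__(self, inner: ToolFileManager, *, user_id: str, tenant_id: str) -> None:
        self._inner = inner
        self._user_id = user_id
        self._tenant_id = tenant_id

    def create_file_by_raw(self, *, conversation_id, file_binary, mimetype, **kwargs):
        return self._inner.create_file_by_raw(
            user_id=self._user_id,
            tenant_id=self._tenant_id,
            conversation_id=conversation_id,
            file_binary=file_binary,
            mimetype=mimetype,
            **kwargs,
        )

FileSaverImpl would then receive this wrapper as tool_file_manager, alongside the injected file-reference factory and HTTP client.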

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, TypeAlias, cast
from core.model_manager import ModelInstance
from dify_graph.file import FileType, file_manager
from dify_graph.file.models import File
from dify_graph.model_runtime.entities import (
@@ -35,15 +34,36 @@ from .exc import (
TemplateTypeNotSupportError,
)
from .protocols import TemplateRenderer
from .runtime_protocols import PreparedLLMProtocol
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
model_instance.model_name,
dict(model_instance.credentials),
)
class _LegacyModelInstance(Protocol):
model_type_instance: object
model_name: str
credentials: object
parameters: Mapping[str, Any]
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ...
PreparedModelInstance: TypeAlias = PreparedLLMProtocol | _LegacyModelInstance
def fetch_model_schema(*, model_instance: PreparedModelInstance) -> AIModelEntity:
get_model_schema = getattr(model_instance, "get_model_schema", None)
if callable(get_model_schema):
model_schema = cast(PreparedLLMProtocol, model_instance).get_model_schema()
else:
legacy_model_instance = cast(_LegacyModelInstance, model_instance)
credentials = legacy_model_instance.credentials
if isinstance(credentials, Mapping):
credentials = dict(credentials)
model_schema = cast(LargeLanguageModel, legacy_model_instance.model_type_instance).get_model_schema(
legacy_model_instance.model_name,
credentials,
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_instance.model_name}")
raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}")
return model_schema
@@ -116,7 +136,7 @@ def fetch_prompt_messages(
sys_files: Sequence[File],
context: str | None = None,
memory: PromptMessageMemory | None = None,
model_instance: ModelInstance,
model_instance: PreparedModelInstance,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
@@ -391,7 +411,7 @@ def combine_message_content_with_role(
raise NotImplementedError(f"Role {role} is not supported")
def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: PreparedModelInstance) -> int:
rest_tokens = 2000
runtime_model_schema = fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
@@ -421,7 +441,7 @@ def handle_memory_chat_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
model_instance: PreparedModelInstance,
) -> Sequence[PromptMessage]:
if not memory or not memory_config:
return []
@@ -436,7 +456,7 @@ def handle_memory_completion_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
model_instance: PreparedModelInstance,
) -> str:
if not memory or not memory_config:
return ""

View File

@@ -7,16 +7,8 @@ import logging
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.tools.signature import sign_upload_file
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
@@ -27,10 +19,11 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.file import File, FileType, file_manager
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.llm_entities import (
@@ -41,7 +34,14 @@ from dify_graph.model_runtime.entities.llm_entities import (
LLMStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
@@ -55,19 +55,23 @@ from dify_graph.node_events import (
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
PromptMessageSerializerProtocol,
RetrieverAttachmentLoaderProtocol,
)
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from dify_graph.runtime import VariablePool
from dify_graph.template_rendering import Jinja2TemplateRenderer, TemplateRenderError
from dify_graph.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from . import llm_utils
from .entities import (
@@ -79,9 +83,12 @@ from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMNodeError,
MemoryRolePrefixRequiredError,
NoPromptFoundError,
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
from .file_saver import LLMFileSaver
if TYPE_CHECKING:
from dify_graph.file.models import File
@@ -101,11 +108,11 @@ class LLMNode(Node[LLMNodeData]):
_file_outputs: list[File]
_llm_file_saver: LLMFileSaver
_credentials_provider: CredentialsProvider
_model_factory: ModelFactory
_model_instance: ModelInstance
_retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None
_prompt_message_serializer: PromptMessageSerializerProtocol
_jinja2_template_renderer: Jinja2TemplateRenderer | None
_model_instance: PreparedLLMProtocol
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
def __init__(
self,
@@ -114,13 +121,15 @@ class LLMNode(Node[LLMNodeData]):
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_instance: ModelInstance,
credentials_provider: object | None = None,
model_factory: object | None = None,
model_instance: PreparedLLMProtocol,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
llm_file_saver: LLMFileSaver,
prompt_message_serializer: PromptMessageSerializerProtocol,
retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None,
jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
):
super().__init__(
id=id,
@@ -131,20 +140,14 @@ class LLMNode(Node[LLMNodeData]):
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._credentials_provider = credentials_provider
self._model_factory = model_factory
_ = credentials_provider, model_factory, http_client
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver
self._prompt_message_serializer = prompt_message_serializer
self._retriever_attachment_loader = retriever_attachment_loader
self._jinja2_template_renderer = jinja2_template_renderer
@classmethod
def version(cls) -> str:
@@ -230,7 +233,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
context_files=context_files,
template_renderer=self._template_renderer,
jinja2_template_renderer=self._jinja2_template_renderer,
)
# handle invoke result
@@ -238,7 +241,6 @@ class LLMNode(Node[LLMNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.require_dify_context().user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self.node_data.structured_output,
file_saver=self._llm_file_saver,
@@ -281,7 +283,7 @@ class LLMNode(Node[LLMNodeData]):
process_data = {
"model_mode": self.node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
"prompts": self._prompt_message_serializer.serialize(
model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
@@ -349,10 +351,9 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def invoke_llm(
*,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
@@ -363,35 +364,28 @@ class LLMNode(Node[LLMNodeData]):
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_parameters = model_instance.parameters
invoke_model_parameters = dict(model_parameters)
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
request_start_time = time.perf_counter()
invoke_result = invoke_llm_with_structured_output(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
invoke_result = model_instance.invoke_llm_with_structured_output(
prompt_messages=prompt_messages,
json_schema=output_schema,
model_parameters=invoke_model_parameters,
stop=list(stop or []),
stop=stop,
stream=True,
user=user_id,
)
else:
request_start_time = time.perf_counter()
invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
prompt_messages=prompt_messages,
model_parameters=invoke_model_parameters,
stop=list(stop or []),
tools=None,
stop=stop,
stream=True,
user=user_id,
)
return LLMNode.handle_invoke_result(
@@ -400,6 +394,7 @@ class LLMNode(Node[LLMNodeData]):
file_outputs=file_outputs,
node_id=node_id,
node_type=node_type,
model_instance=model_instance,
reasoning_format=reasoning_format,
request_start_time=request_start_time,
)
@@ -412,6 +407,7 @@ class LLMNode(Node[LLMNodeData]):
file_outputs: list[File],
node_id: str,
node_type: NodeType,
model_instance: PreparedLLMProtocol | object,
reasoning_format: Literal["separated", "tagged"] = "tagged",
request_start_time: float | None = None,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
@@ -483,8 +479,14 @@ class LLMNode(Node[LLMNodeData]):
usage = result.delta.usage
if finish_reason is None and result.delta.finish_reason:
finish_reason = result.delta.finish_reason
except OutputParserError as e:
raise LLMNodeError(f"Failed to parse structured output: {e}")
except Exception as e:
if hasattr(model_instance, "is_structured_output_parse_error") and cast(
PreparedLLMProtocol, model_instance
).is_structured_output_parse_error(e):
raise LLMNodeError(f"Failed to parse structured output: {e}") from e
if type(e).__name__ == "OutputParserError":
raise LLMNodeError(f"Failed to parse structured output: {e}") from e
raise
# Extract reasoning content from <think> tags in the main text
full_text = full_text_buffer.getvalue()
@@ -687,30 +689,8 @@ class LLMNode(Node[LLMNodeData]):
segment_id = retriever_resource.get("segment_id")
if not segment_id:
continue
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment_id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.require_dify_context().tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attachment_info)
if self._retriever_attachment_loader is not None:
context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id))
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource,
context=context_str.strip(),
@@ -755,7 +735,7 @@ class LLMNode(Node[LLMNodeData]):
sys_files: Sequence[File],
context: str | None = None,
memory: PromptMessageMemory | None = None,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
@@ -764,24 +744,186 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
context_files: list[File] | None = None,
template_renderer: TemplateRenderer | None = None,
jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
return llm_utils.fetch_prompt_messages(
sys_query=sys_query,
sys_files=sys_files,
context=context,
memory=memory,
model_instance=model_instance,
prompt_template=prompt_template,
stop=stop,
memory_config=memory_config,
vision_enabled=vision_enabled,
vision_detail=vision_detail,
variable_pool=variable_pool,
jinja2_variables=jinja2_variables,
context_files=context_files,
template_renderer=template_renderer,
)
prompt_messages: list[PromptMessage] = []
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
LLMNode.handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
jinja2_template_renderer=jinja2_template_renderer,
)
)
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if sys_query:
message = LLMNodeChatModelMessage(
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
LLMNode.handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
jinja2_template_renderer=jinja2_template_renderer,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
jinja2_template_renderer=jinja2_template_renderer,
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
content_item.data = memory_text + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if sys_query:
if isinstance(prompt_content, str):
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
# The sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# The context_files
if vision_enabled and context_files:
file_prompts = []
for file in context_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_schema.features
)
):
continue
prompt_message_content.append(content_item)
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
filtered_prompt_messages.append(prompt_message)
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
return filtered_prompt_messages, stop
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -881,16 +1023,61 @@ class LLMNode(Node[LLMNodeData]):
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
template_renderer: TemplateRenderer | None = None,
jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
return llm_utils.handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail_config,
template_renderer=template_renderer,
)
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
jinja2_template_renderer=jinja2_template_renderer,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
@staticmethod
def handle_blocking_result(
@@ -1027,5 +1214,153 @@ class LLMNode(Node[LLMNodeData]):
return self.node_data.retry_config.retry_enabled
@property
def model_instance(self) -> ModelInstance:
def model_instance(self) -> PreparedLLMProtocol:
return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=contents)
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
case _:
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
jinja2_template_renderer: Jinja2TemplateRenderer | None,
):
if not template:
return ""
jinja2_inputs = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
if jinja2_template_renderer is None:
raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.")
return jinja2_template_renderer.render_template(template, jinja2_inputs)
def _calculate_rest_token(
*,
prompt_messages: list[PromptMessage],
model_instance: PreparedLLMProtocol,
) -> int:
rest_tokens = 2000
runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in runtime_model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _handle_memory_chat_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: PreparedLLMProtocol,
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: PreparedLLMProtocol,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
jinja2_template_renderer: Optional renderer injected for templates whose edition type is jinja2
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
jinja2_template_renderer=jinja2_template_renderer,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
)
prompt_messages.append(prompt_message)
return prompt_messages
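The prompt serializer port used for process_data above can be satisfied by delegating to the utility the node previously called directly. A minimal adapter sketch; the class name is illustrative, the delegated call is the one removed earlier in this file.

from collections.abc import Sequence
from typing import Any

from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.model_runtime.entities import LLMMode, PromptMessage


class PromptMessageSerializer:
    def serialize(self, *, model_mode: LLMMode, prompt_messages: Sequence[PromptMessage]) -> Any:
        return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
            model_mode=model_mode, prompt_messages=list(prompt_messages)
        )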

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_manager import ModelInstance
from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol
class CredentialsProvider(Protocol):
@@ -15,10 +15,10 @@ class CredentialsProvider(Protocol):
class ModelFactory(Protocol):
"""Port for creating initialized LLM model instances for execution."""
"""Port for creating prepared graph-facing LLM runtimes for execution."""
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol:
"""Create a prepared LLM runtime that is ready for graph execution."""
...

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Protocol
from dify_graph.file import File
from dify_graph.model_runtime.entities import LLMMode, PromptMessage
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from dify_graph.model_runtime.entities.message_entities import PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
class PreparedLLMProtocol(Protocol):
"""A graph-facing LLM runtime with provider-specific setup already applied."""
@property
def provider(self) -> str: ...
@property
def model_name(self) -> str: ...
@property
def parameters(self) -> Mapping[str, Any]: ...
@property
def stop(self) -> Sequence[str] | None: ...
def get_model_schema(self) -> AIModelEntity: ...
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ...
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> LLMResult | Generator[LLMResultChunk, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: bool,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def is_structured_output_parse_error(self, error: Exception) -> bool: ...
class PromptMessageSerializerProtocol(Protocol):
"""Port for converting compiled prompt messages into persisted process data."""
def serialize(
self,
*,
model_mode: LLMMode,
prompt_messages: Sequence[PromptMessage],
) -> Any: ...
class RetrieverAttachmentLoaderProtocol(Protocol):
"""Port for resolving retriever segment attachments into graph file references."""
def load(self, *, segment_id: str) -> Sequence[File]: ...
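For unit-testing graph nodes against these ports, a small in-memory fake of the LLM surface is usually enough. Everything below is illustrative: it assumes the LLMResult constructor and LLMUsage.empty_usage() shown here, and it omits the streaming and structured-output paths.

from collections.abc import Mapping, Sequence
from typing import Any

from dify_graph.model_runtime.entities import PromptMessage
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity


class FakePreparedLLM:
    provider = "fake"
    model_name = "fake-llm"
    parameters: Mapping[str, Any] = {}
    stop: Sequence[str] | None = None

    def __init__(self, schema: AIModelEntity, reply: str = "ok") -> None:
        self._schema = schema
        self._reply = reply

    def get_model_schema(self) -> AIModelEntity:
        return self._schema

    def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
        # Crude character-based estimate, good enough for tests.
        return sum(len(str(m.content or "")) for m in prompt_messages) // 4

    def invoke_llm(
        self,
        *,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: Mapping[str, Any],
        tools: Sequence[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: bool,
    ) -> LLMResult:
        # Always returns a blocking result in this sketch, even when stream=True.
        return LLMResult(
            model=self.model_name,
            prompt_messages=list(prompt_messages),
            message=AssistantPromptMessage(content=self._reply),
            usage=LLMUsage.empty_usage(),
        )

    def is_structured_output_parse_error(self, error: Exception) -> bool:
        return False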

View File

@@ -31,9 +31,9 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from dify_graph.utils.condition.processor import ConditionProcessor
from dify_graph.utils.datetime_utils import naive_utc_now
from dify_graph.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
if TYPE_CHECKING:
from dify_graph.graph_engine import GraphEngine

View File

@@ -7,10 +7,10 @@ from pydantic import (
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
from dify_graph.prompt_entities import MemoryConfig
from dify_graph.variables.types import SegmentType
_OLD_BOOL_TYPE_NAME = "bool"

View File

@@ -5,11 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.model_manager import ModelInstance
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
BuiltinNodeTypes,
@@ -17,8 +12,8 @@ from dify_graph.enums import (
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File
from dify_graph.model_runtime.entities import ImagePromptMessageContent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@@ -27,14 +22,15 @@ from dify_graph.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import variable_template_parser
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.llm import llm_utils
from dify_graph.nodes.llm import LLMNode, llm_utils
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables.types import ArrayValidation, SegmentType
from factories.variable_factory import build_segment_with_type
@@ -66,7 +62,6 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.runtime import GraphRuntimeState
@@ -99,9 +94,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: PreparedLLMProtocol
_prompt_message_serializer: PromptMessageSerializerProtocol
_memory: PromptMessageMemory | None
def __init__(
@@ -111,10 +105,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
credentials_provider: object | None = None,
model_factory: object | None = None,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None = None,
prompt_message_serializer: PromptMessageSerializerProtocol,
) -> None:
super().__init__(
id=id,
@@ -122,9 +117,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._credentials_provider = credentials_provider
self._model_factory = model_factory
_ = credentials_provider, model_factory
self._model_instance = model_instance
self._prompt_message_serializer = prompt_message_serializer
self._memory = memory
@classmethod
@@ -164,13 +159,12 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
model_instance = self._model_instance
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
if model_schema.model_type != ModelType.LLM:
raise InvalidModelTypeError("Model is not a Large Language Model")
memory = self._memory
if (
@@ -210,8 +204,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
process_data = {
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
"prompts": self._prompt_message_serializer.serialize(
model_mode=node_data.model.mode,
prompt_messages=prompt_messages,
),
"usage": None,
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
@@ -287,18 +282,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
def _invoke(
self,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: Sequence[str],
stop: Sequence[str] | None,
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=dict(model_instance.parameters),
tools=tools,
stop=list(stop),
stream=False,
user=self.require_dify_context().user_id,
invoke_result = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=dict(model_instance.parameters),
tools=tools or None,
stop=stop,
stream=False,
),
)
# handle invoke result
@@ -317,7 +314,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -329,7 +326,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
content=query, structure=json.dumps(node_data.get_parameter_json_schema())
)
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
@@ -340,15 +336,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
prompt_template = self._get_function_calling_prompt_template(
node_data, query, variable_pool, memory, rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=files,
context="",
memory_config=node_data.memory,
memory=None,
prompt_messages = self._compile_prompt_messages(
model_instance=model_instance,
prompt_template=prompt_template,
files=files,
vision_enabled=node_data.vision.enabled,
image_detail_config=vision_detail,
)
@@ -405,7 +397,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -413,9 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Generate prompt engineering prompt.
"""
model_mode = ModelMode(data.model.mode)
if model_mode == ModelMode.COMPLETION:
if data.model.mode == LLMMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
node_data=data,
query=query,
@@ -425,7 +415,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
vision_detail=vision_detail,
)
elif model_mode == ModelMode.CHAT:
if data.model.mode == LLMMode.CHAT:
return self._generate_prompt_engineering_chat_prompt(
node_data=data,
query=query,
@@ -435,15 +425,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
vision_detail=vision_detail,
)
else:
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}")
def _generate_prompt_engineering_completion_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -451,7 +440,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Generate completion prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
@@ -462,27 +450,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
query="",
files=files,
context="",
memory_config=node_data.memory,
# AdvancedPromptTransform is still typed against TokenBufferMemory.
memory=cast(Any, memory),
return self._compile_prompt_messages(
model_instance=model_instance,
prompt_template=prompt_template,
files=files,
vision_enabled=node_data.vision.enabled,
image_detail_config=vision_detail,
)
return prompt_messages
def _generate_prompt_engineering_chat_prompt(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -490,7 +471,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Generate chat prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
@@ -508,15 +488,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
max_token_limit=rest_token,
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=files,
context="",
memory_config=node_data.memory,
memory=None,
prompt_messages = self._compile_prompt_messages(
model_instance=model_instance,
prompt_template=prompt_template,
files=files,
vision_enabled=node_data.vision.enabled,
image_detail_config=vision_detail,
)
@@ -717,8 +693,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
variable_pool: VariablePool,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
) -> list[LLMNodeChatModelMessage]:
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -727,15 +702,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
if node_data.model.mode == LLMMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
else:
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.")
def _get_prompt_engineering_prompt_template(
self,
@@ -744,8 +718,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
variable_pool: VariablePool,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -754,64 +727,53 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
if node_data.model.mode == LLMMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
elif model_mode == ModelMode.COMPLETION:
return CompletionModelPromptTemplate(
if node_data.model.mode == LLMMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=COMPLETION_GENERATE_JSON_PROMPT.format(
histories=memory_str, text=input_text, instruction=instruction
)
.replace("{γγγ", "")
.replace("}γγγ", "")
.replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())),
)
else:
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.")
def _calculate_rest_token(
self,
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
context: str | None,
) -> int:
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
else:
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
prompt_messages = self._compile_prompt_messages(
model_instance=model_instance,
prompt_template=prompt_template,
files=[],
vision_enabled=False,
context=context,
)
rest_tokens = 2000
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
curr_message_tokens = (
model_type_instance.get_num_tokens(
model_instance.model_name, model_instance.credentials, prompt_messages
)
+ 1000
) # add 1000 to ensure tool call messages
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
@@ -828,8 +790,34 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
def _compile_prompt_messages(
self,
*,
model_instance: PreparedLLMProtocol,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
files: Sequence[File],
vision_enabled: bool,
context: str | None = "",
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
prompt_messages, _ = LLMNode.fetch_prompt_messages(
sys_query="",
sys_files=files,
context=context,
memory=None,
model_instance=model_instance,
prompt_template=prompt_template,
stop=model_instance.stop,
memory_config=None,
vision_enabled=vision_enabled,
vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
return list(prompt_messages)
@property
def model_instance(self) -> ModelInstance:
def model_instance(self) -> PreparedLLMProtocol:
return self._model_instance
@classmethod

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Protocol
import httpx
@@ -35,8 +35,6 @@ class ToolFileManagerProtocol(Protocol):
def create_file_by_raw(
self,
*,
user_id: str,
tenant_id: str,
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
@@ -44,3 +42,7 @@ class ToolFileManagerProtocol(Protocol):
) -> Any: ...
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ...
class FileReferenceFactoryProtocol(Protocol):
def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ...

View File

@@ -1,9 +1,9 @@
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.llm import ModelConfig, VisionConfig
from dify_graph.prompt_entities import MemoryConfig
class ClassConfig(BaseModel):

View File

@@ -3,9 +3,6 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.model_manager import ModelInstance
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
@@ -14,7 +11,7 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from dify_graph.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult
@@ -27,10 +24,11 @@ from dify_graph.nodes.llm import (
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.protocols import TemplateRenderer
from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
@@ -49,15 +47,20 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
class _PassthroughPromptMessageSerializer:
def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any:
_ = model_mode
return list(prompt_messages)
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_prompt_message_serializer: PromptMessageSerializerProtocol
_model_instance: PreparedLLMProtocol
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
@@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
credentials_provider: object | None = None,
model_factory: object | None = None,
model_instance: PreparedLLMProtocol,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
llm_file_saver: LLMFileSaver,
prompt_message_serializer: PromptMessageSerializerProtocol | None = None,
):
super().__init__(
id=id,
@@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._credentials_provider = credentials_provider
self._model_factory = model_factory
_ = credentials_provider, model_factory, http_client
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver
self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer()
@classmethod
def version(cls):
@@ -169,7 +166,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.require_dify_context().user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
@@ -205,7 +201,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
category_id = category_id_result
process_data = {
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
"prompts": self._prompt_message_serializer.serialize(
model_mode=node_data.model.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
@@ -247,7 +243,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
)
@property
def model_instance(self) -> ModelInstance:
def model_instance(self) -> PreparedLLMProtocol:
return self._model_instance
@classmethod
@@ -285,7 +281,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
model_instance: ModelInstance,
model_instance: PreparedLLMProtocol,
context: str | None,
) -> int:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
@@ -334,7 +330,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
model_mode = LLMMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:
@@ -350,7 +346,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
if model_mode == LLMMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
)
@@ -381,7 +377,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
elif model_mode == LLMMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
histories=memory_str,

View File

@@ -0,0 +1,75 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Protocol
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.tool_runtime_entities import (
ToolRuntimeHandle,
ToolRuntimeMessage,
ToolRuntimeParameter,
)
if TYPE_CHECKING:
from dify_graph.nodes.human_input.entities import DeliveryChannelConfig
from dify_graph.nodes.tool.entities import ToolNodeData
from dify_graph.runtime import VariablePool
class ToolNodeRuntimeProtocol(Protocol):
"""Workflow-layer adapter owned by `core.workflow` and consumed by `dify_graph`.
The graph package depends only on these DTOs and lets the workflow layer
translate between graph-owned abstractions and `core.tools` internals.
"""
def get_runtime(
self,
*,
node_id: str,
node_data: ToolNodeData,
variable_pool: VariablePool | None,
) -> ToolRuntimeHandle: ...
def get_runtime_parameters(
self,
*,
tool_runtime: ToolRuntimeHandle,
) -> Sequence[ToolRuntimeParameter]: ...
def invoke(
self,
*,
tool_runtime: ToolRuntimeHandle,
tool_parameters: Mapping[str, Any],
workflow_call_depth: int,
conversation_id: str | None,
provider_name: str,
) -> Generator[ToolRuntimeMessage, None, None]: ...
def get_usage(
self,
*,
tool_runtime: ToolRuntimeHandle,
) -> LLMUsage: ...
def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ...
def resolve_provider_icons(
self,
*,
provider_name: str,
default_icon: str | None = None,
) -> tuple[str | None, str | None]: ...
class HumanInputNodeRuntimeProtocol(Protocol):
def invoke_source(self) -> str: ...
def apply_delivery_runtime(
self,
*,
methods: Sequence[DeliveryChannelConfig],
) -> Sequence[DeliveryChannelConfig]: ...
def console_actor_id(self) -> str | None: ...
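Both protocols above are structural (typing.Protocol), so the workflow layer only needs to hand the graph a plain object whose methods match these signatures; no dify_graph base class has to be inherited. A minimal sketch of that idea, using a simplified stand-in protocol and adapter whose names are illustrative rather than taken from the codebase:

from collections.abc import Mapping, Sequence
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class _RuntimeAdapter(Protocol):
    """Simplified stand-in for a graph-owned runtime protocol."""

    def get_runtime_parameters(self, *, tool_name: str) -> Sequence[Mapping[str, Any]]: ...


class _StaticAdapter:
    """Workflow-layer object that satisfies the protocol without subclassing it."""

    def get_runtime_parameters(self, *, tool_name: str) -> Sequence[Mapping[str, Any]]:
        # The real adapter would consult core.tools; a fixture stands in here.
        return [{"name": "query", "required": True, "tool": tool_name}]


adapter: _RuntimeAdapter = _StaticAdapter()  # accepted by structural typing alone
assert isinstance(adapter, _RuntimeAdapter)  # runtime_checkable permits this isinstance check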

View File

@@ -4,9 +4,10 @@ from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData
from dify_graph.nodes.template_transform.template_renderer import (
from dify_graph.template_rendering import (
Jinja2TemplateRenderer,
TemplateRenderError,
)
@@ -20,7 +21,7 @@ DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
_jinja2_template_renderer: Jinja2TemplateRenderer
_max_output_length: int
def __init__(
@@ -30,7 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer,
jinja2_template_renderer: Jinja2TemplateRenderer,
max_output_length: int | None = None,
) -> None:
super().__init__(
@@ -39,7 +40,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer
self._jinja2_template_renderer = jinja2_template_renderer
if max_output_length is not None and max_output_length <= 0:
raise ValueError("max_output_length must be a positive integer")
@@ -70,7 +71,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
@@ -87,9 +88,32 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: TemplateTransformNodeData | Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
}
_ = graph_config
raw_variables = (
node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", [])
)
variable_mapping: dict[str, Sequence[str]] = {}
for variable_selector in raw_variables:
if isinstance(variable_selector, VariableSelector):
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
continue
if not isinstance(variable_selector, Mapping):
continue
variable = variable_selector.get("variable")
value_selector = variable_selector.get("value_selector")
if (
isinstance(variable, str)
and isinstance(value_selector, Sequence)
and all(isinstance(selector_part, str) for selector_part in value_selector)
):
variable_mapping[node_id + "." + variable] = list(value_selector)
return variable_mapping

View File

@@ -1,13 +1,27 @@
from enum import StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
class ToolProviderType(StrEnum):
"""
Graph-owned enum for persisted tool provider kinds.
"""
PLUGIN = auto()
BUILT_IN = "builtin"
WORKFLOW = auto()
API = auto()
APP = auto()
DATASET_RETRIEVAL = "dataset-retrieval"
MCP = auto()
class ToolEntity(BaseModel):
provider_id: str
provider_type: ToolProviderType
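Note that StrEnum members declared with auto() take the lower-cased member name as their value, which is why BUILT_IN and DATASET_RETRIEVAL are given explicit strings (auto() would have produced "built_in" and "dataset_retrieval"). A small illustration of that behaviour, independent of the real enum:

from enum import StrEnum, auto


class _Kind(StrEnum):
    PLUGIN = auto()                          # value becomes "plugin"
    BUILT_IN = "builtin"                     # explicit: auto() would give "built_in"
    DATASET_RETRIEVAL = "dataset-retrieval"  # explicit: auto() cannot produce a hyphen


assert _Kind.PLUGIN == "plugin"
assert _Kind.BUILT_IN.value == "builtin"
assert _Kind.DATASET_RETRIEVAL == "dataset-retrieval"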

View File

@@ -4,6 +4,18 @@ class ToolNodeError(ValueError):
pass
class ToolRuntimeResolutionError(ToolNodeError):
"""Raised when the workflow layer cannot construct a tool runtime."""
pass
class ToolRuntimeInvocationError(ToolNodeError):
"""Raised when the workflow layer fails while invoking a tool runtime."""
pass
class ToolParameterError(ToolNodeError):
"""Exception raised for errors in tool parameters."""

View File

@@ -1,12 +1,6 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
BuiltinNodeTypes,
@@ -20,10 +14,14 @@ from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEven
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.nodes.runtime import ToolNodeRuntimeProtocol
from dify_graph.nodes.tool_runtime_entities import (
ToolRuntimeHandle,
ToolRuntimeMessage,
ToolRuntimeParameter,
)
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
from dify_graph.variables.variables import ArrayAnyVariable
from factories import file_factory
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
from .exc import (
@@ -52,6 +50,7 @@ class ToolNode(Node[ToolNodeData]):
graph_runtime_state: "GraphRuntimeState",
*,
tool_file_manager_factory: ToolFileManagerProtocol,
runtime: ToolNodeRuntimeProtocol | None = None,
):
super().__init__(
id=id,
@@ -60,6 +59,9 @@ class ToolNode(Node[ToolNodeData]):
graph_runtime_state=graph_runtime_state,
)
self._tool_file_manager_factory = tool_file_manager_factory
if runtime is None:
raise ValueError("runtime is required")
self._runtime = runtime
@classmethod
def version(cls) -> str:
@@ -73,10 +75,6 @@ class ToolNode(Node[ToolNodeData]):
"""
Run the tool node
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
dify_ctx = self.require_dify_context()
# fetch tool icon
tool_info = {
"provider_type": self.node_data.provider_type.value,
@@ -86,8 +84,6 @@ class ToolNode(Node[ToolNodeData]):
# get tool runtime
try:
from core.tools.tool_manager import ToolManager
# This is an issue that caused problems before.
# Logically, we shouldn't use the node_data.version field for judgment,
# but we keep this check for backward compatibility with historical data.
@@ -95,13 +91,10 @@ class ToolNode(Node[ToolNodeData]):
variable_pool: VariablePool | None = None
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
dify_ctx.tenant_id,
dify_ctx.app_id,
self._node_id,
self.node_data,
dify_ctx.invoke_from,
variable_pool,
tool_runtime = self._runtime.get_runtime(
node_id=self._node_id,
node_data=self.node_data,
variable_pool=variable_pool,
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@@ -116,7 +109,7 @@ class ToolNode(Node[ToolNodeData]):
return
# get parameters
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime)
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
@@ -132,14 +125,12 @@ class ToolNode(Node[ToolNodeData]):
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
message_stream = self._runtime.invoke(
tool_runtime=tool_runtime,
tool_parameters=parameters,
user_id=dify_ctx.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
provider_name=self.node_data.provider_name,
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@@ -159,38 +150,16 @@ class ToolNode(Node[ToolNodeData]):
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_id=self._node_id,
tool_runtime=tool_runtime,
)
except ToolInvokeError as e:
except ToolNodeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
error_type=type(e).__name__,
)
)
except PluginInvokeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
error_type=type(e).__name__,
)
)
except PluginDaemonClientSideError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool, error: {e.description}",
error=str(e),
error_type=type(e).__name__,
)
)
@@ -198,7 +167,7 @@ class ToolNode(Node[ToolNodeData]):
def _generate_parameters(
self,
*,
tool_parameters: Sequence[ToolParameter],
tool_parameters: Sequence[ToolRuntimeParameter],
variable_pool: "VariablePool",
node_data: ToolNodeData,
for_log: bool = False,
@@ -207,7 +176,7 @@ class ToolNode(Node[ToolNodeData]):
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
@@ -247,40 +216,29 @@ class ToolNode(Node[ToolNodeData]):
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
messages: Generator[ToolRuntimeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_id: str,
tool_runtime: Tool,
tool_runtime: ToolRuntimeHandle,
**_: Any,
) -> Generator[NodeEventBase, None, LLMUsage]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
Convert graph-owned tool runtime messages into node outputs.
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict | list] = []
variables: dict[str, Any] = {}
for message in message_stream:
for message in messages:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
ToolRuntimeMessage.MessageType.IMAGE_LINK,
ToolRuntimeMessage.MessageType.BINARY_LINK,
ToolRuntimeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
url = message.message.text
if message.meta:
@@ -300,14 +258,11 @@ class ToolNode(Node[ToolNodeData]):
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
file = self._runtime.build_file_reference(mapping=mapping)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
elif message.type == ToolRuntimeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
@@ -320,27 +275,22 @@ class ToolNode(Node[ToolNodeData]):
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
files.append(self._runtime.build_file_reference(mapping=mapping))
elif message.type == ToolRuntimeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
elif message.type == ToolRuntimeMessage.MessageType.JSON:
assert isinstance(message.message, ToolRuntimeMessage.JsonMessage)
# JSON message handling for tool node
if message.message.json_object:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
elif message.type == ToolRuntimeMessage.MessageType.LINK:
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
# Check if this LINK message is a file link
file_obj = (message.meta or {}).get("file")
@@ -356,8 +306,8 @@ class ToolNode(Node[ToolNodeData]):
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
elif message.type == ToolRuntimeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolRuntimeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
@@ -374,7 +324,7 @@ class ToolNode(Node[ToolNodeData]):
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
elif message.type == ToolRuntimeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
# Validate that meta contains a 'file' key
@@ -385,38 +335,16 @@ class ToolNode(Node[ToolNodeData]):
if not isinstance(message.meta["file"], File):
raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
elif message.type == ToolRuntimeMessage.MessageType.LOG:
assert isinstance(message.message, ToolRuntimeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
icon, icon_dark = self._runtime.resolve_provider_icons(
provider_name=dict_metadata["provider"],
default_icon=icon,
)
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
@@ -446,7 +374,7 @@ class ToolNode(Node[ToolNodeData]):
is_final=True,
)
usage = self._extract_tool_usage(tool_runtime)
usage = self._runtime.get_usage(tool_runtime=tool_runtime)
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
@@ -468,21 +396,6 @@ class ToolNode(Node[ToolNodeData]):
return usage
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
# Avoid importing WorkflowTool at module import time; rely on duck typing
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
latest = getattr(tool_runtime, "latest_usage", None)
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
# for any name, so we must type-check here.
if isinstance(latest, LLMUsage):
return latest
if isinstance(latest, dict):
# Allow dict payloads from external runtimes
return LLMUsage.model_validate(latest)
# Fallback to empty usage when attribute is missing or not a valid payload
return LLMUsage.empty_usage()
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class _ToolRuntimeModel(BaseModel):
model_config = ConfigDict(extra="forbid")
@dataclass(frozen=True, slots=True)
class ToolRuntimeHandle:
"""Opaque graph-owned handle for a workflow-layer tool runtime."""
raw: object
@dataclass(frozen=True, slots=True)
class ToolRuntimeParameter:
"""Graph-owned parameter shape used by tool nodes."""
name: str
required: bool = False
class ToolRuntimeMessage(_ToolRuntimeModel):
"""Graph-owned tool invocation message DTO."""
class TextMessage(_ToolRuntimeModel):
text: str
class JsonMessage(_ToolRuntimeModel):
json_object: dict[str, Any] | list[Any]
suppress_output: bool = Field(default=False)
class BlobMessage(_ToolRuntimeModel):
blob: bytes
class BlobChunkMessage(_ToolRuntimeModel):
id: str
sequence: int
total_length: int
blob: bytes
end: bool
class FileMessage(_ToolRuntimeModel):
file_marker: str = Field(default="file_marker")
class VariableMessage(_ToolRuntimeModel):
variable_name: str
variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None
stream: bool = Field(default=False)
class LogMessage(_ToolRuntimeModel):
class LogStatus(StrEnum):
START = auto()
ERROR = auto()
SUCCESS = auto()
id: str
label: str
parent_id: str | None = None
error: str | None = None
status: LogStatus
data: dict[str, Any]
metadata: dict[str, Any] = Field(default_factory=dict)
class RetrieverResourceMessage(_ToolRuntimeModel):
retriever_resources: list[dict[str, Any]]
context: str
class MessageType(StrEnum):
TEXT = auto()
IMAGE = auto()
LINK = auto()
BLOB = auto()
JSON = auto()
IMAGE_LINK = auto()
BINARY_LINK = auto()
VARIABLE = auto()
FILE = auto()
LOG = auto()
BLOB_CHUNK = auto()
RETRIEVER_RESOURCES = auto()
type: MessageType = MessageType.TEXT
message: (
JsonMessage
| TextMessage
| BlobChunkMessage
| BlobMessage
| LogMessage
| FileMessage
| None
| VariableMessage
| RetrieverResourceMessage
)
meta: dict[str, Any] | None = None
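Because every DTO above forbids extra fields, the workflow layer has to translate core.tools messages into these shapes field by field rather than passing them through. A hedged sketch of constructing such a message, with the class re-declared in simplified form (only the TEXT shape) so the snippet stands alone:

from enum import StrEnum, auto
from typing import Any

from pydantic import BaseModel, ConfigDict


class _Msg(BaseModel):
    """Simplified stand-in for ToolRuntimeMessage, modelling only the TEXT shape."""

    model_config = ConfigDict(extra="forbid")

    class TextMessage(BaseModel):
        model_config = ConfigDict(extra="forbid")
        text: str

    class MessageType(StrEnum):
        TEXT = auto()
        JSON = auto()

    type: MessageType = MessageType.TEXT
    message: TextMessage | None = None
    meta: dict[str, Any] | None = None


msg = _Msg(type=_Msg.MessageType.TEXT, message=_Msg.TextMessage(text="hello"))
assert msg.message is not None and msg.message.text == "hello"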

View File

@@ -0,0 +1,47 @@
from typing import Literal
from pydantic import BaseModel
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
class ChatModelMessage(BaseModel):
"""Graph-owned chat prompt template message."""
text: str
role: PromptMessageRole
edition_type: Literal["basic", "jinja2"] | None = None
class CompletionModelPromptTemplate(BaseModel):
"""Graph-owned completion prompt template."""
text: str
edition_type: Literal["basic", "jinja2"] | None = None
class MemoryConfig(BaseModel):
"""Graph-owned memory configuration for prompt assembly."""
class RolePrefix(BaseModel):
"""Role labels used when serializing completion-model histories."""
user: str
assistant: str
class WindowConfig(BaseModel):
"""History windowing controls."""
enabled: bool
size: int | None = None
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
__all__ = [
"ChatModelMessage",
"CompletionModelPromptTemplate",
"MemoryConfig",
]
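These graph-owned prompt entities replace the shapes previously imported from core.prompt, so nodes keep constructing memory configuration the same way. A short, hedged example of building a MemoryConfig with a bounded history window; the import path is the one used earlier in this diff and the field values are illustrative:

from dify_graph.prompt_entities import MemoryConfig  # import path used earlier in this diff

memory_config = MemoryConfig(
    window=MemoryConfig.WindowConfig(enabled=True, size=10),  # keep only the last 10 history turns
)
assert memory_config.window.size == 10
assert memory_config.query_prompt_template is None  # optional fields default to None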

View File

@@ -18,9 +18,6 @@ class FormNotFoundError(HumanInputError):
@dataclasses.dataclass
class FormCreateParams:
# app_id is the identifier for the app that the form belongs to.
# It is a string with uuid format.
app_id: str
# None when creating a delivery test form; set for runtime forms.
workflow_execution_id: str | None
@@ -45,6 +42,9 @@ class FormCreateParams:
resolved_default_values: Mapping[str, Any]
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
# Optional application identifier. Implementations may bind this at construction time.
app_id: str | None = None
# Force creating a console-only recipient for submission in Console.
console_recipient_required: bool = False
console_creator_account_id: str | None = None

View File

@@ -8,19 +8,17 @@ from dify_graph.nodes.code.entities import CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
"""Raised when rendering a template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
"""Shared contract for rendering Jinja2 templates in graph nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
def render_template(self, template: str, variables: Mapping[str, Any]) -> str: ...
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
"""Adapter that renders Jinja2 templates via the workflow code executor."""
_code_executor: WorkflowCodeExecutor

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
import abc
import datetime
from typing import Protocol
class _NowFunction(Protocol):
@abc.abstractmethod
def __call__(self, tz: datetime.timezone | None) -> datetime.datetime:
"""Return the current time for the requested timezone."""
...
_now_func: _NowFunction = datetime.datetime.now
def naive_utc_now() -> datetime.datetime:
"""Return the current UTC time as a naive datetime."""
return _now_func(datetime.UTC).replace(tzinfo=None)
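Routing the clock through the module-level _now_func hook keeps naive_utc_now trivially fakeable in tests. A hedged sketch of that swap, with the helper re-declared locally so the example runs on its own:

import datetime
from typing import Protocol


class _NowFunction(Protocol):
    def __call__(self, tz: datetime.timezone | None) -> datetime.datetime: ...


# Re-declared here so the snippet stands alone; mirrors the new helper above.
_now_func: _NowFunction = datetime.datetime.now


def naive_utc_now() -> datetime.datetime:
    return _now_func(datetime.UTC).replace(tzinfo=None)


# A test can pin the clock by swapping the module-level hook:
_fixed = datetime.datetime(2026, 3, 15, 12, 0, 0, tzinfo=datetime.UTC)
_now_func = lambda tz: _fixed  # the helper always passes datetime.UTC, ignored by the fake
assert naive_utc_now() == datetime.datetime(2026, 3, 15, 12, 0, 0)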

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
import json
class OutputParserError(ValueError):
"""Raised when a markdown-wrapped JSON payload cannot be parsed or validated."""
def parse_json_markdown(json_string: str) -> dict | list:
"""Extract and parse the first JSON object or array embedded in markdown text."""
json_string = json_string.strip()
starts = ["```json", "```", "``", "`", "{", "["]
ends = ["```", "``", "`", "}", "]"]
end_index = -1
start_index = 0
for start_marker in starts:
start_index = json_string.find(start_marker)
if start_index != -1:
if json_string[start_index] not in ("{", "["):
start_index += len(start_marker)
break
if start_index != -1:
for end_marker in ends:
end_index = json_string.rfind(end_marker, start_index)
if end_index != -1:
if json_string[end_index] in ("}", "]"):
end_index += 1
break
if start_index == -1 or end_index == -1 or start_index >= end_index:
raise ValueError("could not find json block in the output.")
extracted_content = json_string[start_index:end_index].strip()
return json.loads(extracted_content)
def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as exc:
raise OutputParserError(f"got invalid json object. error: {exc}") from exc
if isinstance(json_obj, list):
if len(json_obj) == 1 and isinstance(json_obj[0], dict):
json_obj = json_obj[0]
else:
raise OutputParserError(f"got invalid return object. obj:{json_obj}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserError(
f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
)
return json_obj
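The parser accepts both fenced and bare JSON payloads. A hedged usage sketch; the import path below is the one the question classifier node switches to earlier in this diff, so the snippet assumes it runs inside the repository:

from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown

llm_output = """Here is the classification:
```json
{"category_id": "2", "category_name": "refund"}
```"""

result = parse_and_check_json_markdown(llm_output, expected_keys=["category_id"])
assert result["category_name"] == "refund"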

View File

@@ -92,7 +92,7 @@ class AppService:
default_model_config = default_model_config.copy() if default_model_config else None
if default_model_config and "model" in default_model_config:
# get model provider
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=account.current_tenant_id or "")
# get default model instance
try:

View File

@@ -61,7 +61,7 @@ class AudioService:
message = f"Audio size larger than {FILE_SIZE} mb"
raise AudioTooLargeServiceError(message)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user)
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
)
@@ -71,7 +71,7 @@ class AudioService:
buffer = io.BytesIO(file_content)
buffer.name = "temp.mp3"
return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
return {"text": model_instance.invoke_speech2text(file=buffer)}
@classmethod
def transcript_tts(
@@ -109,7 +109,7 @@ class AudioService:
voice = cast(str | None, text_to_speech_dict.get("voice"))
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user)
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
)
@@ -123,9 +123,7 @@ class AudioService:
else:
raise ValueError("Sorry, no voice available.")
return model_instance.invoke_tts(
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
)
return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice)
except Exception as e:
raise e
@@ -155,7 +153,7 @@ class AudioService:
@classmethod
def transcript_tts_voices(cls, tenant_id: str, language: str):
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
if model_instance is None:
raise ProviderNotSupportTextToSpeechServiceError()
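The service-layer changes all follow one pattern: bind the tenant (and optionally the user) once via ModelManager.for_tenant, then drop the per-call user/tenant arguments from the invoke methods. A toy, hedged stand-in illustrating that factory-binding idea; it is not the real ModelManager API beyond the for_tenant entry point shown in the diff:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class _TenantBoundManager:
    """Toy stand-in for the tenant-bound factory pattern used in the services above."""

    tenant_id: str
    user_id: str | None = None

    @classmethod
    def for_tenant(cls, *, tenant_id: str, user_id: str | None = None) -> "_TenantBoundManager":
        # Bind tenant/user once instead of threading them through every invoke_* call.
        return cls(tenant_id=tenant_id, user_id=user_id)

    def invoke_tts(self, *, content_text: str, voice: str) -> str:
        # The bound identity is available without extra arguments at the call site.
        return f"[{self.tenant_id}/{self.user_id}] {voice}: {content_text}"


manager = _TenantBoundManager.for_tenant(tenant_id="tenant-1", user_id="end-user-1")
print(manager.invoke_tts(content_text="hello", voice="alloy"))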

View File

@@ -228,7 +228,7 @@ class DatasetService:
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
if embedding_model_provider and embedding_model_name:
# check if embedding model setting is valid
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name)
@@ -350,7 +350,7 @@ class DatasetService:
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
@@ -367,7 +367,7 @@ class DatasetService:
@staticmethod
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_manager.get_model_instance(
tenant_id=tenant_id,
provider=embedding_model_provider,
@@ -384,7 +384,7 @@ class DatasetService:
@staticmethod
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=model_provider,
@@ -405,7 +405,7 @@ class DatasetService:
@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model_provider,
@@ -742,7 +742,7 @@ class DatasetService:
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
@@ -860,7 +860,7 @@ class DatasetService:
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
try:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
@@ -954,7 +954,7 @@ class DatasetService:
dataset.chunk_structure = knowledge_configuration.chunk_structure
dataset.indexing_technique = knowledge_configuration.indexing_technique
if knowledge_configuration.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, # ignore type error
provider=knowledge_configuration.embedding_model_provider or "",
@@ -996,7 +996,7 @@ class DatasetService:
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_configuration.embedding_model_provider,
@@ -1049,7 +1049,7 @@ class DatasetService:
or knowledge_configuration.embedding_model != dataset.embedding_model
):
action = "update"
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
embedding_model = None
try:
embedding_model = model_manager.get_model_instance(
@@ -1908,7 +1908,7 @@ class DocumentService:
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
@@ -2220,7 +2220,7 @@ class DocumentService:
# dataset.indexing_technique = knowledge_config.indexing_technique
# if knowledge_config.indexing_technique == "high_quality":
# model_manager = ModelManager()
# model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
# if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
# dataset_embedding_model = knowledge_config.embedding_model
# dataset_embedding_model_provider = knowledge_config.embedding_model_provider
@@ -3125,7 +3125,7 @@ class SegmentService:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -3208,7 +3208,7 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -3346,7 +3346,7 @@ class SegmentService:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
if dataset.embedding_model_provider:
embedding_model_instance = model_manager.get_model_instance(
@@ -3409,7 +3409,7 @@ class SegmentService:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -3450,7 +3450,7 @@ class SegmentService:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
if dataset.embedding_model_provider:
embedding_model_instance = model_manager.get_model_instance(

View File

@@ -255,7 +255,7 @@ class MessageService:
app_model=app_model, conversation_id=message.conversation_id, user=user
)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id)
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow_service = WorkflowService()

Some files were not shown because too many files have changed in this diff.